diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 86322616..8165ec95 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -20,4 +20,4 @@ A clear and concise description of what you expected to happen. - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)] **Additional context** -Add any other context about the problem here. \ No newline at end of file +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/documentation_request.md b/.github/ISSUE_TEMPLATE/documentation_request.md index 9e96105f..c9fa21fa 100644 --- a/.github/ISSUE_TEMPLATE/documentation_request.md +++ b/.github/ISSUE_TEMPLATE/documentation_request.md @@ -32,4 +32,4 @@ A clear and concise description of what documentation you believe it is needed a A clear and concise description of what you want to happen. **Steps taken to search for needed documentation** -List any steps you have taken: \ No newline at end of file +List any steps you have taken: diff --git a/.github/ISSUE_TEMPLATE/submit_question.md b/.github/ISSUE_TEMPLATE/submit_question.md index 743f893f..5aa2a672 100644 --- a/.github/ISSUE_TEMPLATE/submit_question.md +++ b/.github/ISSUE_TEMPLATE/submit_question.md @@ -7,4 +7,4 @@ assignees: '' --- -**What is your question?** \ No newline at end of file +**What is your question?** diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 6510938e..23956a02 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -8,4 +8,4 @@ jobs: steps: - uses: actions/labeler@main with: - repo-token: "${{ secrets.GITHUB_TOKEN }}" \ No newline at end of file + repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/new-issues-to-triage-projects.yml b/.github/workflows/new-issues-to-triage-projects.yml index 3049176e..a963cb2f 100644 --- a/.github/workflows/new-issues-to-triage-projects.yml +++ b/.github/workflows/new-issues-to-triage-projects.yml @@ -32,4 +32,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass - GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing' \ No newline at end of file + GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing' diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index cb2e7275..8b65da69 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -54,4 +54,4 @@ jobs: exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue" days-before-pr-stale: 90 days-before-pr-close: -1 - operations-per-run: 50 \ No newline at end of file + operations-per-run: 50 diff --git a/CHANGELOG.md b/CHANGELOG.md index d09b4981..2418c845 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # NVIDIA CUTLASS Changelog +## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23) +* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) +* [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) +* Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel +* [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention) +* [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) +* Updates and bugfixes from the community (thanks!) 
+ +* **Deprecation announcement:** CUTLASS plans to deprecate the following: + * Maxwell and Pascal GPU architectures + * Ubuntu 16.04 + * CUDA 10.2 + ## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21) * [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment @@ -37,6 +50,7 @@ * Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads) * Updates and bugfixes from the community (thanks!) + ## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19) * **TF32x3:** emulated single-precision using Tensor Cores diff --git a/CITATION.cff b/CITATION.cff deleted file mode 100644 index ea053e66..00000000 --- a/CITATION.cff +++ /dev/null @@ -1,82 +0,0 @@ -cff-version: 1.2.0 -title: CUTLASS -message: >- - If you use this software, please cite using the - following metadata. -type: software -authors: - - given-names: Andrew - email: akerr@nvidia.com - family-names: Kerr - affiliation: NVIDIA - - given-names: Haicheng - family-names: Wu - affiliation: NVIDIA - email: haichengw@nvidia.com - - given-names: Manish - family-names: Gupta - affiliation: Google - email: manigupta@google.com - - given-names: Dustyn - family-names: Blasig - email: dblasig@nvidia.com - affiliation: NVIDIA - - given-names: Pradeep - family-names: Ramini - email: prramani@nvidia.com - affiliation: NVIDIA - - given-names: Duane - family-names: Merrill - email: dumerrill@nvidia.com - affiliation: NVIDIA - - given-names: Aniket - family-names: Shivam - email: ashivam@nvidia.com - affiliation: NVIDIA - - given-names: Piotr - family-names: Majcher - email: pmajcher@nvidia.com - affiliation: NVIDIA - - given-names: Paul - family-names: Springer - email: pspringer@nvidia.com - affiliation: NVIDIA - - given-names: Markus - family-names: Hohnerbach - affiliation: NVIDIA - email: mhohnerbach@nvidia.com - - given-names: Jin - family-names: Wang - email: jinw@nvidia.com - affiliation: NVIDIA - - given-names: Matt - family-names: Nicely - email: mnicely@nvidia.com - affiliation: NVIDIA -repository-code: 'https://github.com/NVIDIA/cutlass' -abstract: >- - CUTLASS is a collection of CUDA C++ template - abstractions for implementing high-performance - matrix-multiplication (GEMM) and related - computations at all levels and scales within CUDA. - It incorporates strategies for hierarchical - decomposition and data movement similar to those - used to implement cuBLAS and cuDNN. CUTLASS - decomposes these "moving parts" into reusable, - modular software components abstracted by C++ - template classes. These thread-wide, warp-wide, - block-wide, and device-wide primitives can be - specialized and tuned via custom tiling sizes, data - types, and other algorithmic policy. The resulting - flexibility simplifies their use as building blocks - within custom kernels and applications. 
-keywords: - - 'cutlass, tensor cores, cuda' -license: BSD-3-Clause -license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt -version: '2.9' -date-released: '2022-04-27' -identifiers: - - type: url - value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0" - description: The GitHub release URL of tag 2.9.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index cfed600b..30e261c2 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.9.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.10.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 10.2) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index dccfbda6..576f5ae1 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -11,12 +11,19 @@ Andrew Kerr Haicheng Wu Manish Gupta Dustyn Blasig -Pradeep Ramani +Pradeep Ramani +Cris Cecka +Vijay Thakkar +Aniket Shivam +Honghao Lu +Ethan Yan +Zhaodong Chen +Jack Kosaian +Yujia Zhai Naila Farooqui Piotr Majcher Paul Springer Jin Wang -Aniket Shivam Chinmay Talegaonkar Shang Zhang Scott Yokim @@ -53,7 +60,6 @@ Nick Zhao ## ACKNOWLEDGEMENTS Girish Bharambe -Cris Cecka Luke Durant Olivier Giroux Stephen Jones diff --git a/README.md b/README.md index 78ca725c..c8e29f40 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.9 +# CUTLASS 2.10 -_CUTLASS 2.9 - April 2022_ +_CUTLASS 2.10 - August 2022_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) and related computations at all levels @@ -18,7 +18,9 @@ To support a wide variety of applications, CUTLASS provides extensive support fo mixed-precision computations, providing specialized data-movement and multiply-accumulate abstractions for half-precision floating point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32), -single-precision floating point (FP32), double-precision floating +single-precision floating point (FP32), +[FP32 emulation via tensor core instruction](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), +double-precision floating point (FP64) types, integer data types (4b and 8b), and binary data types (1b). CUTLASS demonstrates warp-synchronous matrix multiply operations targeting the programmable, high-throughput _Tensor Cores_ implemented by @@ -34,26 +36,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. See the [functionality listing](/media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. 
-# What's New in CUTLASS 2.9 +# What's New in CUTLASS 2.10 -CUTLASS 2.9 is an update to CUTLASS adding: -- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment -- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores - - [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu), - - [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu), - - [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and - - [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu) -- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python) -- [GEMM + Softmax example](/examples/35_gemm_softmax) -- [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel. -- [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. Bias Vector add is also supported in the first GEMM/CONV. -- [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation. -- [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC. -- [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores. -- Epilogue enhancement with performance improvement, more activation functions, and more fusion patterns. -- [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix. -- Optimal performance using [CUDA 11.7](https://developer.nvidia.com/cuda-downloads) -- [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler. +CUTLASS 2.10 is an update to CUTLASS adding: +- [Grouped convolution targeting implicit GEMM](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) +- [Depthwise separable convolution](test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) +- Optimizations for CUTLASS's [Grouped GEMM](examples/24_gemm_grouped/gemm_grouped.cu) kernel +- [Grouped GEMM for Multihead Attention](examples/50_multi_head_attention) +- [GEMM + Layer norm fusion for Ampere](examples/37_gemm_layernorm_gemm_fusion/) - Updates and bugfixes from the community (thanks!) 
- **Deprecation announcement:** CUTLASS plans to deprecate the following: - Maxwell and Pascal GPU architectures @@ -249,15 +239,15 @@ examples/ 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu - 13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel + 13_fused_two_gemms/ # example demonstrating two GEMMs fused in one kernel 22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores - 31_basic_syrk # example demonstrating Symetric rank-K update + 31_basic_syrk # example demonstrating Symmetric Rank-K update - 32_basic_trmm # + 32_basic_trmm # example demonstrating Triangular Matrix-Matrix multiplication - 33_ampere_3xtf32_tensorop_symm # + 33_ampere_3xtf32_tensorop_symm # example demonstrating Symmetric Matrix-Matrix multiplication with FP32 emulation 35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores diff --git a/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/examples/12_gemm_bias_relu/gemm_bias_relu.cu index 62eb2940..f996d542 100644 --- a/examples/12_gemm_bias_relu/gemm_bias_relu.cu +++ b/examples/12_gemm_bias_relu/gemm_bias_relu.cu @@ -54,12 +54,11 @@ using ElementInputA = cutlass::half_t; // <- data type of elements using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D -// The code section below describes matrix layout of input and output matrices. -// Column Major for Matrix A, B and C. - // Note that if the output is column major, the bias has to be per row. i.e. every row has different bias. // If the output is row major, the bias has to be per column, i.e. every column has different bias. // Below list some other notices: +// +// Note this example only works for ColumnMajor output because // 1) we only have row major epilogue. // 2) we swap A and B if the output is column major then we can still use the // row major epilogue. diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 66b0dee5..8c1e26f5 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -457,9 +457,13 @@ Result profile_convolution(Options const &options) { ElementInputB(-8), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( @@ -686,7 +690,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
<< std::endl; notSupported = true; diff --git a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu index 2b6b25c7..cf744ecb 100644 --- a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu +++ b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu @@ -290,7 +290,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu index 62450e21..d65eb040 100644 --- a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu @@ -326,7 +326,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 41ea3200..a6958b8f 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -32,7 +32,7 @@ /** The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere -HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor +16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor core instructions. 
Most of the reduction is done in gemm/warp level, see gemm/warp/mma_with_reduction_tensor_op.h @@ -67,9 +67,9 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h // elements using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -using ElementOutput = cutlass::half_t; // Data type of elements in output tensor +using ElementInputA = cutlass::bfloat16_t; // Data type of elements in input tensor +using ElementInputB = cutlass::bfloat16_t; // Data type of elements in input tensor +using ElementOutput = cutlass::bfloat16_t; // Data type of elements in output tensor using LayoutInputA = cutlass::layout::ColumnMajor; using LayoutInputB = cutlass::layout::RowMajor; @@ -369,22 +369,22 @@ Result profile(Options const &options) { cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, - ElementInputA(4), - ElementInputA(-4), + ElementInputA(2), + ElementInputA(-2), 0); // <- Fill tensor A on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, - ElementInputB(4), - ElementInputB(-4), + ElementInputB(2), + ElementInputB(-2), 0); // <- Fill tensor B on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, - ElementOutput(4), - ElementOutput(-4), + ElementOutput(2), + ElementOutput(-2), 0); // <- Fill matrix C on host with uniform-distribution random data cutlass::reference::host::TensorFill( tensor_d.host_view()); // <- fill matrix D on host with zeros @@ -612,10 +612,10 @@ Result profile(Options const &options) { if (options.reference_check) { output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n"; - output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n"; + output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n"; } - output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl; output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl; std::cout << "Results written to '" << ss.str() << "'." << std::endl; @@ -699,7 +699,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
<< std::endl; notSupported = true; diff --git a/examples/24_gemm_grouped/gemm_grouped.cu b/examples/24_gemm_grouped/gemm_grouped.cu index a32c80d7..1000f359 100644 --- a/examples/24_gemm_grouped/gemm_grouped.cu +++ b/examples/24_gemm_grouped/gemm_grouped.cu @@ -66,6 +66,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +#include #include #include #include @@ -98,6 +99,7 @@ struct Result { double runtime_ms; + double initialization_time_ms; double gflops; cutlass::Status status; cudaError_t error; @@ -109,11 +111,13 @@ struct Result { Result( double runtime_ms = 0, + double initialization_time_ms = 0, double gflops = 0, cutlass::Status status = cutlass::Status::kSuccess, cudaError_t error = cudaSuccess ): - runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,6 +138,8 @@ struct Options { bool help; bool error; bool reference_check; + bool profile_initialization; + bool sort_problems; std::vector problem_sizes; @@ -155,6 +161,29 @@ struct Options { std::string output_tag; std::ofstream output_file; + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; + // // Methods // @@ -164,12 +193,15 @@ struct Options { error(false), alignment(8), reference_check(true), + profile_initialization(false), + sort_problems(false), problem_count(15), iterations(20), cuda_streams(0), verbose(false), alpha(1), - beta() + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) { } // Parses the command line @@ -189,8 +221,35 @@ struct Options { cmd.get_cmd_line_argument("streams", cuda_streams, 0); cmd.get_cmd_line_argument("verbose", verbose, false); cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); cmd.get_cmd_line_argument("benchmark", benchmark_path); + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." 
<< std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + std::string output_path; cmd.get_cmd_line_argument("tag", output_tag); cmd.get_cmd_line_argument("output_file", output_path); @@ -314,6 +373,8 @@ struct Options { /// Post processes the problems void bin_problems() { + problem_bins.clear(); + problem_count = int(problem_sizes.size()); // @@ -340,19 +401,22 @@ struct Options { << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" << " in device Global Memory and loaded by the kernel.\n\n" << "Options:\n\n" - << " --help If specified, displays this usage statement.\n\n" - << " --benchmark= Executes a benchmark problem size.\n" - << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" - << " --tag= String tag to prepend to the CSV file.\n" - << " --groups= Number of individual GEMM problems (default: --groups=15)\n" - << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" - << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" - << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" - << " --alpha= Epilogue scalar alpha (real part)\n" - << " --beta= Epilogue scalar beta (real part)\n\n" - << " --iterations= Number of profiling iterations to perform.\n" - << " --reference-check= If true, performs reference check.\n" - << " --verbose= If true, prints problem sizes and batching structure.\n"; + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual GEMM problems (default: --groups=15)\n" + << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. 
Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; out << "\n\nExamples:\n\n" @@ -365,6 +429,12 @@ struct Options { << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" + << "# Runs a grouped GEMM with each different scheduler mode\n" + << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" + << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" << "#\n" @@ -399,10 +469,9 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template -class TestbedGrouped { +template +class BaseTestbed { public: - // // Type definitions // @@ -421,8 +490,6 @@ public: using MatrixCoord = typename LayoutC::TensorCoord; -private: - // // Data members // @@ -462,13 +529,7 @@ private: cutlass::DeviceAllocation ptr_C; cutlass::DeviceAllocation ptr_D; -public: - - // - // Methods - // - - TestbedGrouped( + BaseTestbed( Options &options_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, @@ -481,11 +542,9 @@ public: return options.problem_count; } -private: - /// Helper to initialize a tensor view template - void initialize_tensor_( + void initialize_tensor( Element *ptr, size_t capacity, cutlass::Distribution::Kind dist_kind, @@ -539,65 +598,13 @@ private: } } - /// Verbose printing of problem sizes - void print_problem_sizes_() { - - // Print groups - std::cout << problem_count() << " groups:\n"; - - int32_t idx = 0; - int64_t total_tiles = 0; - - for (auto const & problem : options.problem_sizes) { - - int tiles = - ((problem.m() + Gemm::ThreadblockShape::kM - 1) / Gemm::ThreadblockShape::kM) * - ((problem.n() + Gemm::ThreadblockShape::kN - 1) / Gemm::ThreadblockShape::kN); - - total_tiles += tiles; - - std::cout << " [" << idx << "]: " - << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() - << " (" << tiles << " threadblock tiles)" << "\n"; - - ++idx; - } - - // Print batched GEMM equivalent - size_t bin_idx = 0; - size_t problem_count_check = 0; - std::cout << "\nConventionally executed as " << options.problem_bins.size() << " batched GEMMs:\n"; - for (auto const & bin : options.problem_bins) { - - std::cout << " [" << bin_idx << "]: " - << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() - << ", batch count: " << bin.second.size() << "\n"; - - ++bin_idx; - problem_count_check += bin.second.size(); - } - - 
if (problem_count_check != problem_count()) { - std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; - } - } - - /// Initializes data structures - void initialize_() { - - // - // Choose random problem sizes - // - - // construct a few problems of random sizes - srand(seed); - + /// Allocates device-side data + void allocate() { int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; int64_t total_elements_D = 0; - lda_host.resize(problem_count()); ldb_host.resize(problem_count()); ldc_host.resize(problem_count()); @@ -628,14 +635,22 @@ private: total_elements_D += elements_D; } - problem_sizes_device.reset(problem_count()); - problem_sizes_device.copy_from_host(options.problem_sizes.data()); - lda.reset(problem_count()); ldb.reset(problem_count()); ldc.reset(problem_count()); ldd.reset(problem_count()); + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + lda.copy_from_host(lda_host.data()); ldb.copy_from_host(ldb_host.data()); ldc.copy_from_host(ldc_host.data()); @@ -645,11 +660,6 @@ private: // Assign pointers // - block_A.reset(total_elements_A); - block_B.reset(total_elements_B); - block_C.reset(total_elements_C); - block_D.reset(total_elements_D); - std::vector ptr_A_host(problem_count()); std::vector ptr_B_host(problem_count()); std::vector ptr_C_host(problem_count()); @@ -678,16 +688,16 @@ private: // Initialize the problems of the workspace // - initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); - initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); - initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); cutlass::reference::device::BlockFillSequential( - block_D.get(), total_elements_D, ElementC(), ElementC()); + block_D.get(), block_D.size(), ElementC(), ElementC()); } /// Verifies the result is a GEMM - bool verify_() { + bool verify() { bool passed = true; @@ -738,7 +748,7 @@ private: cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); - + // Reference check passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); @@ -751,227 +761,62 @@ private: return passed; } +}; + +template +class TestbedBatched : BaseTestbed { public: + TestbedBatched( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} - /// Returns the number of threadblocks to launch if the kernel can run on the target - /// device. Otherwise, returns zero. 
- int sufficient() const { - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); + void print_problem_sizes() { + std::cout << std::endl; + size_t bin_idx = 0; + size_t problem_count_check = 0; + std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n"; + for (auto const & bin : this->options.problem_bins) { - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); + std::cout << " [" << bin_idx << "]: " + << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() + << ", batch count: " << bin.second.size() << "\n"; + + ++bin_idx; + problem_count_check += bin.second.size(); } - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); + if (problem_count_check != this->problem_count()) { + std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; } - int occupancy = Gemm::maximum_active_blocks(); - - return properties.multiProcessorCount * occupancy; - + std::cout << std::endl; } - - /// Executes a Grouped GEMM kernel and measures runtime. - Result profile_grouped() { + /// Executes a batched kernel and measures runtime + Result profile() { + std::cout << "Batched GEMM:\n" + << "====================================================" << std::endl; Result result; - - int threadblock_count = sufficient(); - - // Early exit - if (!threadblock_count) { - std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; - return result; - } - - if (options.verbose) { - print_problem_sizes_(); - } - result.passed = false; // Initialize the problem - initialize_(); + this->allocate(); + this->initialize(); - // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); - - // Configure GEMM arguments - typename Gemm::Arguments args( - problem_sizes_device.get(), - problem_count(), - threadblock_count, - epilogue_op, - ptr_A.get(), - ptr_B.get(), - ptr_C.get(), - ptr_D.get(), - lda.get(), - ldb.get(), - ldc.get(), - ldd.get() - ); - - // Initialize the GEMM object - Gemm gemm; - - result.status = gemm.initialize(args); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; - return result; + if (this->options.verbose) { + print_problem_sizes(); } - // Run the grouped GEMM object - result.status = gemm.run(); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; - return result; - } - - // Wait for completion - result.error = cudaDeviceSynchronize(); - - if (result.error != cudaSuccess) { - std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); - return result; - } - - // - // Verify correctness - // - result.passed = true; - - if (options.reference_check) { - result.passed = verify_(); - } - - // - // Warm-up run of the grouped GEMM object - // - result.status = gemm.run(); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." 
<< std::endl; - return result; - } - - // - // Construct events - // - - cudaEvent_t events[2]; - - for (auto & event : events) { - result.error = cudaEventCreate(&event); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; - return -1; - } - } - - // Record an event at the start of a series of GEMM operations - result.error = cudaEventRecord(events[0]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // - // Run profiling loop - // - - for (int iter = 0; iter < options.iterations; ++iter) { - gemm(); - } - - // - // Stop profiling loop - // - - // Record an event when the GEMM operations have been launched. - result.error = cudaEventRecord(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Measure elapsed runtime - float runtime_ms = 0; - result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Compute average runtime and GFLOPs. - result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); - - // - // Cleanup - // - - for (auto event : events) { - (void)cudaEventDestroy(event); - } - - int32_t idx = 0; - int64_t total_tiles = 0; - - for (auto const & problem : options.problem_sizes) { - - int tiles = - ((problem.m() + Gemm::ThreadblockShape::kM - 1) / Gemm::ThreadblockShape::kM) * - ((problem.n() + Gemm::ThreadblockShape::kN - 1) / Gemm::ThreadblockShape::kN); - - total_tiles += tiles; - ++idx; - } - - std::cout << std::endl; - std::cout << "Grouped GEMM (CUTLASS):\n" - << "====================================================" << std::endl; - - std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; - - std::cout << std::endl; - std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; - std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; - - if (options.output_file.good()) { - options.output_file << options.output_tag << ",CUTLASS,grouped," - << problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; - } - - return result; - } - - /// Executes a conventional batched GEMM kernel. - Result profile_batched() { - - Result result; - result.passed = false; - // // Prepare batched GEMM environment // - int32_t effective_streams = (options.cuda_streams ? options.cuda_streams : 1); + int32_t effective_streams = (this->options.cuda_streams ? 
this->options.cuda_streams : 1); // Array of leading dimensions used by batched GEMM calls std::vector bin_problem_sizes; @@ -985,15 +830,15 @@ public: std::vector ptr_B_batched_host; std::vector ptr_C_batched_host; - for (auto const & bin : options.problem_bins) { + for (auto const & bin : this->options.problem_bins) { int first_idx = bin.second.front(); - bin_problem_sizes.push_back(options.problem_sizes.at(first_idx)); + bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx)); bin_count.push_back(int32_t(bin.second.size())); - bin_ldm_A.push_back(static_cast(lda_host.at(first_idx))); - bin_ldm_B.push_back(static_cast(ldb_host.at(first_idx))); - bin_ldm_C.push_back(static_cast(ldc_host.at(first_idx))); + bin_ldm_A.push_back(static_cast(this->lda_host.at(first_idx))); + bin_ldm_B.push_back(static_cast(this->ldb_host.at(first_idx))); + bin_ldm_C.push_back(static_cast(this->ldc_host.at(first_idx))); if (ptr_A_batched_host.size() % 2) { ptr_A_batched_host.push_back(nullptr); @@ -1005,29 +850,29 @@ public: for (int idx : bin.second) { - if (bin_problem_sizes.back() != options.problem_sizes.at(idx)) { + if (bin_problem_sizes.back() != this->options.problem_sizes.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_A.back() != lda_host.at(idx)) { + if (bin_ldm_A.back() != this->lda_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_B.back() != ldb_host.at(idx)) { + if (bin_ldm_B.back() != this->ldb_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_C.back() != ldc_host.at(idx)) { + if (bin_ldm_C.back() != this->ldc_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - ptr_A_batched_host.push_back(block_A.get() + offset_A.at(idx)); - ptr_B_batched_host.push_back(block_B.get() + offset_B.at(idx)); - ptr_C_batched_host.push_back(block_D.get() + offset_C.at(idx)); + ptr_A_batched_host.push_back(this->block_A.get() + this->offset_A.at(idx)); + ptr_B_batched_host.push_back(this->block_B.get() + this->offset_B.at(idx)); + ptr_C_batched_host.push_back(this->block_D.get() + this->offset_C.at(idx)); } } @@ -1048,15 +893,14 @@ public: // Create CUDA streams to maximize concurrency of batched-array GEMM kernels // std::vector cuda_streams; - char const *provider = "CUTLASS"; // // Warmup run // - if (options.cuda_streams) { - for (int i = 0; i < options.cuda_streams; ++i) { + if (this->options.cuda_streams) { + for (int i = 0; i < this->options.cuda_streams; ++i) { cudaStream_t stream; result.error = cudaStreamCreate(&stream); @@ -1074,7 +918,7 @@ public: } // Use 'D' for the in/out workspace - block_D.copy_from_device(block_C.get()); + this->block_D.copy_from_device(this->block_C.get()); for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { @@ -1094,9 +938,9 @@ public: // // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); - typename GemmBatched::Arguments arguments{ + typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, problem, batch_count, @@ -1115,7 +959,7 @@ public: int64_t(ldc) }; - GemmBatched gemm_op; + Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); @@ -1182,7 +1026,7 @@ public: int last_stream_idx = 0; - for (int iter = 0; iter < options.iterations; ++iter) { + for (int iter = 0; iter < 
this->options.iterations; ++iter) { for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { @@ -1204,9 +1048,9 @@ public: // // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); - typename GemmBatched::Arguments arguments{ + typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, problem, batch_count, @@ -1225,7 +1069,7 @@ public: int64_t(ldc) }; - GemmBatched gemm_op; + Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); @@ -1266,20 +1110,6 @@ public: return result; } - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[0]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - // Measure elapsed runtime float runtime_ms = 0; result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); @@ -1289,8 +1119,8 @@ public: } // Compute average runtime and GFLOPs. - result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); // // Cleanup @@ -1306,18 +1136,16 @@ public: } } - std::cout << std::endl; - std::cout << "Batched GEMM:\n" - << "====================================================" << std::endl; - - std::cout << " " << bin_problem_sizes.size() << " batched GEMMs launched" << std::endl; + std::cout << " " << this->options.problem_bins.size() << " batched GEMMs launched" << std::endl; std::cout << std::endl; std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl; std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl; - if (options.output_file.good()) { - options.output_file << options.output_tag << "," << provider << ",batched," - << problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + std::string provider = "CUTLASS"; + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << "," << provider << ",batched," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; } result.passed = true; @@ -1325,14 +1153,286 @@ public: } }; +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine GEMM with different GroupScheduleMode_ + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + typename Gemm_::ElementA, + typename Gemm_::LayoutA, + Gemm_::kTransformA, + Gemm_::kAlignmentA, + typename Gemm_::ElementB, + typename Gemm_::LayoutB, + Gemm_::kTransformB, + Gemm_::kAlignmentB, + typename Gemm_::ElementC, + typename Gemm_::LayoutC, + typename 
Gemm_::ElementAccumulator, + typename Gemm_::OperatorClass, + typename Gemm_::ArchTag, + typename Gemm_::ThreadblockShape, + typename Gemm_::WarpShape, + typename Gemm_::InstructionShape, + typename Gemm_::EpilogueOutputOp, + typename Gemm_::ThreadblockSwizzle, + Gemm_::kStages, + GroupScheduleMode_>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + std::cout << std::endl; + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Gemm::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Gemm::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime + Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + + std::cout << std::endl; + std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the GEMM arguments + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = gemm.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Run the grouped GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." 
<< std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + gemm.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + gemm.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Gemm::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." 
<< std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // - // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. - // - cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, 0); @@ -1359,7 +1459,7 @@ int main(int argc, char const **args) { // Options options; - + options.parse(argc, args); if (options.help) { @@ -1373,9 +1473,11 @@ int main(int argc, char const **args) { } // - // Define the Grouped GEMM type + // Define the Grouped and Batched GEMM types // + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementAccumulator = float; @@ -1383,6 +1485,30 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; + // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 + using GemmBatched = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, LayoutA, + cutlass::half_t, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4 + >; + + // Define a grouped GEMM kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. 
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< cutlass::half_t, LayoutA, @@ -1407,59 +1533,42 @@ int main(int argc, char const **args) { using GemmGrouped = cutlass::gemm::device::GemmGrouped; - // - // Define a conventional batched GEMM type - // - - // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 - using GemmBatched = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, LayoutA, - cutlass::half_t, LayoutB, - ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 4 - >; - // // Profile it // - TestbedGrouped testbed(options); - - if (!testbed.sufficient()) { - std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; - return 0; + TestbedBatched testbed_batched(options); + Result result = testbed_batched.profile(); + if (result.error) { + return 1; } - Result result = testbed.profile_grouped(); - if (!result.passed) { - std::cout << "Profiling CUTLASS grouped GEMM has failed.\n"; - std::cout << "\nFailed\n"; - return -1; + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; } - result = testbed.profile_batched(); - if (!result.passed) { - - std::cout << "Profiling batched GEMM has failed.\n"; - std::cout << "\nFailed\n"; - return -1; - } - - std::cout << "\nPassed\n"; - return 0; } diff --git a/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt b/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt index 0bf0c775..4cac74d6 100644 --- a/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt +++ b/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt @@ -34,3 +34,8 @@ cutlass_example_add_executable( ampere_fprop_mainloop_fusion.cu ) +cutlass_example_add_executable( + 25_ampere_3d_fprop_mainloop_fusion + ampere_3d_fprop_mainloop_fusion.cu + ) + diff --git a/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu b/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu new file mode 100644 index 00000000..2f3d36ed --- /dev/null +++ b/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu @@ -0,0 +1,776 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/**
+
+This example shows how to fuse per-channel scale+bias+relu of the activations
+into the 3D fprop mainloop.
+
+Compared with the original 3D fprop kernel, this example has two more vectors, one
+for the scale and one for the bias. The length of each vector equals the number of
+activation channels. The kernel loads the vectors when the associated activation
+channels are loaded in the mainloop. Between reading the activations and
+scale/bias data from shared memory and calling the tensor core instructions,
+scale+bias+relu is computed in the register file.
+
+This example is customized for the Ampere 16816 (m16n8k16) fp16 tensor core
+instruction. Changing to different data types or a different tensor core
+instruction requires source code changes. See
+include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more
+technical details.
+
+This example is adapted from 25_ampere_fprop_mainloop_fusion and uses the same
+command line.
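+
+Concretely, the fused mainloop transforms each activation element x in channel c
+in registers as
+
+    y = max(0, scale[c] * x + bias[c])
+
+before it is consumed by the tensor core instruction; scale and bias are the two
+extra vectors described above. (This formula is an illustrative summary of the
+fusion, not the exact device code.)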
+*/
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h"
+#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/device/convolution.h"
+#include "cutlass/util/tensor_view_io.h"
+
+#include "helper.h"
+
+// The code section below describes the data types of the input and output tensors
+// and of the computation between elements
+using ElementAccumulator = float;                  // Data type of accumulator
+using ElementComputeEpilogue = float;              // Data type of epilogue computation (alpha, beta)
+using ElementInputA = cutlass::half_t;             // Data type of elements in input tensor
+using ElementInputB = cutlass::half_t;             // Data type of elements in input tensor
+using ElementInputScaleBias = cutlass::half_t;     // Data type of elements in input scale and bias vectors
+using ElementOutput = float;                       // Data type of elements in output tensor
+
+using LayoutInputA = cutlass::layout::TensorNDHWC;
+using LayoutInputB = cutlass::layout::TensorNDHWC;
+using LayoutInputScaleBias = cutlass::layout::RowMajor;
+using LayoutOutput = cutlass::layout::TensorNDHWC;
+
+// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
+using MMAOp = cutlass::arch::OpClassTensorOp;
+
+// This code section describes the CUDA SM architecture number
+using SmArch = cutlass::arch::Sm80;
+
+// This code section describes the tile size a thread block will compute
+using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;  // Threadblock tile shape
+
+// This code section describes the tile size a warp will compute
+using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;           // Warp tile shape
+
+// This code section describes the size of the MMA op
+using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;     // TensorCore instruction shape
+
+// This code section describes how threadblocks are scheduled on the GPU
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
+
+// Number of pipeline stages to use
+constexpr int NumStages = 4;
+
+// This code section describes whether the Analytic or the Optimized iterator algorithm is selected
+static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;
+
+// This code section describes the epilogue part of the kernel; we use the default value
+using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
+    ElementOutput,                                     // Data type of output matrix.
+    128 / cutlass::sizeof_bits<ElementOutput>::value,  // The number of elements per vectorized
+                                                       // memory access. This becomes the vector width of
+                                                       // math instructions in the epilogue too.
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear combination + +using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm +>::Kernel; + +using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + cutlass::Tensor5DCoord input_size; + cutlass::Tensor5DCoord filter_size; + cutlass::Coord<3> padding; + cutlass::Coord<3> conv_stride; + cutlass::Coord<3> dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options(): + help(false), + input_size(1, 32, 32, 32, 32), + filter_size(32, 3, 3, 3, 32), + padding(cutlass::make_Coord(1, 1, 1)), + conv_stride(cutlass::make_Coord(1, 1, 1)), + dilation(cutlass::make_Coord(1, 1, 1)), + reference_check(true), + measure_performance(false), + iterations(20), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) { } + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. 
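+ // (128 bits per access / 16 bits per cutlass::half_t element = 8 elements,
+ // hence kAlignment = 8 below.)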
+ // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || + (filter_size.n() % kAlignment)) { + + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding[0] != filter_size.d() / 2) || + (padding[1] != filter_size.h() / 2) || + (padding[2] != filter_size.w() / 2)) { + + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update( + cutlass::Tensor5DCoord input_size, + cutlass::Tensor5DCoord filter_size, + cutlass::Coord<3> stride) { + + this->input_size = input_size; + this->filter_size = filter_size; + conv_stride = stride; + + padding[0] = filter_size.d() / 2; + padding[1] = filter_size.h() / 2; + padding[2] = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("save-workspace")) { + save_workspace = true; + } + + if (cmd.check_cmd_line_flag("benchmark")) { + benchmark = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("d", input_size.d()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("t", filter_size.d()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + filter_size.c() = input_size.c(); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) { + padding = cutlass::make_Coord(1, 1, 1); + } + else { + filter_size.d() = 1; + filter_size.h() = 1; + filter_size.w() = 1; + padding = cutlass::make_Coord(0, 0, 0); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "25_ampere_3d_fprop_mainloop_fusion example\n\n" + << " This example fuses scale+bias+relu of the activations into Ampere's\n" + << " Tensor Core operators on F16 data types to compute\n" + << " forward convolution on tensors of layout NDHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n Input tensor extent N\n" + << " --d Input tensor extent D\n" + << " --h Input tensor extent H\n" + << " --w Input tensor extent W\n" + << " --c Input tensor extent C\n" + << " --k Filter extent K\n" + << " --t Filter extent T\n" + << " --r Filter extent R\n" + << " --s Filter extent S\n\n" + << " --alpha Epilogue scalar alpha\n" + << " --beta Epilogue scalar beta\n\n" + << " --ref-check If set (true), reference check on the host is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" + << " --iterations Number of profiling iterations to perform.\n" + << " --save-workspace If set, workspace is written to a text file.\n" + << " --tag String to replicate across the first column in the results table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor5DCoord output_size() const { + return cutlass::Tensor5DCoord( + input_size.n(), + (input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1, + (input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1, + (input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of multiply-adds = NPQK * CRS + int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result(): + runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) { } + + static std::ostream & print_header(std::ostream &out, Options const &options) { + + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs"; + + return out; + } + + std::ostream & print(std::ostream &out, int idx, Options const &options) { + + if (!options.tag.empty()) { + out << options.tag << ","; + } + + out + << "conv_" << idx << "," + << options.input_size.n() << "," + << options.input_size.d() << "," + << options.input_size.h() << "," + << options.input_size.w() << "," + << options.input_size.c() << "," + << options.filter_size.n() << "," + << options.filter_size.d() << "," + << options.filter_size.h() << "," + << options.filter_size.w() << "," + << 
options.conv_stride[0] << "," + << options.conv_stride[1] << "," + << options.conv_stride[2] << "," + << runtime_ms << "," + << gflops; + + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Runs one benchmark +Result profile_convolution(Options const &options) { + + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_transformed_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor + tensor_a_scale({1, options.input_size.c()}); + cutlass::HostTensor + tensor_a_bias({1, options.input_size.c()}); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_d(options.output_size()); + cutlass::HostTensor tensor_ref_d(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill scale vector for tensor A on host with uniform-distribution random + // data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a_scale.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill bias vector for tensor A on host with uniform-distribution random + // data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a_bias.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(7), + ElementInputB(-8), + 0); + + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); + + // Fill tensor D for reference on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_a_scale.sync_device(); + tensor_a_bias.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Construct Conv3dProblemSize with user defined output size + cutlass::conv::Conv3dProblemSize problem_size( + options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices + ); + + typename ImplicitGemmFusion::Arguments arguments{ + problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_a_scale.device_ref(), + tensor_a_bias.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + {options.alpha, options.beta}, + }; + + // + // Initialize CUTLASS Convolution + // + + ImplicitGemmFusion implicit_gemm_fusion_op; + + size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_fusion_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); + 
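+ // (can_implement(), get_workspace_size(), and initialize() above follow the
+ // standard CUTLASS device-level API sequence; the functor call further below
+ // launches the fused kernel.)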
CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_fusion_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on device...\n"; + + // Compute scale + bias + relu in host code + for (int n = 0; n < options.input_size.n(); ++n) { + for (int d = 0; d < options.input_size.d(); ++d) { + for (int h = 0; h < options.input_size.h(); ++h) { + for (int w = 0; w < options.input_size.w(); ++w) { + for (int c = 0; c < options.input_size.c(); ++c) { + tensor_transformed_a.at({n, d, h, w, c}) = std::max( + ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) * + tensor_a_scale.at({0, c}) + + tensor_a_bias.at({0, c}))); + } + } + } + } + } + + tensor_transformed_a.sync_device(); + + // Compute with reference implementation + cutlass::reference::device::Conv3dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator, + cutlass::NumericConverter + >( + problem_size, + tensor_transformed_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_ref_d.device_ref(), + options.alpha, + options.beta + ); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } + else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } + else { + result.reference_check = cutlass::Status::kInvalid; + } + + if (options.save_workspace) { + + std::stringstream ss; + + ss << "25_ampere_3d_fprop_mainloop_fusion" + << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() + << "_" + << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() + << ".dat"; + + std::ofstream output_workspace(ss.str()); + + output_workspace + << "Input = \n" << tensor_a.host_view() << "\n\n" + << "Filters = \n" << tensor_b.host_view() << "\n\n"; + + if (options.reference_check) { + output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; + } + + output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + + std::cout << "Results written to '" << ss.str() << "'." << std::endl; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_fusion_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. 
+ result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major >= 8)) { + std::cerr << "This test must run on SM80 or above.\n"; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.benchmark) { + // Benchmark several layers + + int batch_sizes[] = {34, 18}; + + struct Benchmark { + int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w; + } layers[] = { + {56, 56, 56, 64, 256, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 64, 64, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 64, 64, 3, 3, 3, 1, 1, 1}, + {56, 56, 56, 256, 64, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 256, 512, 1, 1, 1, 2, 2, 2}, + {56, 56, 56, 256, 128, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 128, 128, 3, 3, 3, 2, 2, 2}, + {28, 28, 28, 128, 512, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 512, 128, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 128, 128, 3, 3, 3, 1, 1, 1}, + {28, 28, 28, 512, 1024, 1, 1, 1, 2, 2, 2}, + {28, 28, 28, 512, 256, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 256, 256, 3, 3, 3, 2, 2, 2}, + {14, 14, 14, 256, 1024, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 1024, 256, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 256, 256, 3, 3, 3, 1, 1, 1}, + {14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2}, + {14, 14, 14, 1024, 512, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 512, 512, 3, 3, 3, 2, 2, 2}, + { 7, 7, 7, 512, 2048, 1, 1, 1, 1, 1, 1}, + { 7, 7, 7, 2048, 512, 1, 1, 1, 1, 1, 1}, + { 7, 7, 7, 512, 512, 3, 3, 3, 1, 1, 1}, + }; + + Result::print_header(std::cout, options) << std::endl; + + int idx = 1; + + for (auto const &layer : layers) { + for (auto N : batch_sizes) { + options.update({N, layer.d, layer.h, layer.w, layer.c}, + {layer.k, layer.t, layer.r, layer.s, layer.c}, + cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w)); + + Result result = profile_convolution(options); + result.print(std::cout, idx, options) 
<< std::endl; + } + + ++idx; + } + } + else { + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + Result result = profile_convolution(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu b/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu index fe756fba..4b8af864 100644 --- a/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu +++ b/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu @@ -429,9 +429,13 @@ Result profile_convolution(Options const &options) { ElementInputB(-8), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( @@ -575,7 +579,7 @@ Result profile_convolution(Options const &options) { std::stringstream ss; - ss << "25_ampere_fprop_mainloop_fusion_" + ss << "25_ampere_fprop_mainloop_fusion" << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() @@ -677,8 +681,8 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major == 8 && props.minor == 0)) { - std::cerr << "This test must run on SM80 A100.\n"; + if (!(props.major >= 8)) { + std::cerr << "This test must run on SM80 or above.\n"; notSupported = true; } diff --git a/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu b/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu index 72d7284f..a9ee283b 100644 --- a/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu +++ b/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu @@ -266,8 +266,8 @@ struct Options { /// Prints the usage statement. 
std::ostream & print_usage(std::ostream &out) const { - out << "26_ampere_fused_wgrad_batch_normalization example\n\n" - << " This example fuses scale+bias+relu from batch norm into Ampere's\n" + out << "26_ampere_wgrad_mainloop_fusion example\n\n" + << " This example fuses scale+bias+relu of the activation into Ampere's\n" << " Tensor Core operators on F16 data types to compute\n" << " backward convolution on tensors of layout NHWC.\n\n" << "Options:\n\n" @@ -289,8 +289,8 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" - << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; + << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" + << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; return out; } @@ -427,9 +427,13 @@ Result profile_convolution(Options const &options) { ElementInputA(-4), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( diff --git a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu index a197e2ef..32271687 100644 --- a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu +++ b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu @@ -740,7 +740,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/examples/30_wgrad_split_k/30_wgrad_split_k.cu index 5016adf2..0c7f32f4 100644 --- a/examples/30_wgrad_split_k/30_wgrad_split_k.cu +++ b/examples/30_wgrad_split_k/30_wgrad_split_k.cu @@ -703,7 +703,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." 
<< std::endl; notSupported = true; diff --git a/examples/34_transposed_conv2d/34_transposed_conv2d.cu b/examples/34_transposed_conv2d/34_transposed_conv2d.cu index d9d878ad..a0d08486 100644 --- a/examples/34_transposed_conv2d/34_transposed_conv2d.cu +++ b/examples/34_transposed_conv2d/34_transposed_conv2d.cu @@ -603,7 +603,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 0d18077e..55b87b3c 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -47,14 +47,17 @@ #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_reduce.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/error_metrics.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/layout/matrix.h" #include "cutlass/epilogue/thread/linear_combination.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -85,18 +88,18 @@ struct Options { float alpha; float beta; bool verification_enabled; - double tolerance; + float tolerance; Options(): help(false), problem_size({16, 24, 64}), - batch_count(1), // As a temporary limitation to the test bench, batch count must be 1. The kernels support arbitrary batching. 
+ batch_count(16), iterations(20), seed(2022), alpha(1), - beta(), + beta(0), verification_enabled(true), - tolerance(0.01) + tolerance(1e-5f) { } bool valid() { @@ -116,6 +119,8 @@ struct Options { cmd.get_cmd_line_argument("n", problem_size.n()); cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_count", batch_count); + cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); @@ -135,6 +140,7 @@ struct Options { << " --m= GEMM M dimension\n" << " --n= GEMM N dimension\n" << " --k= GEMM K dimension\n" + << " --batch_count= Batch number\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" << " --seed= Random number seed (1*)\n\n" @@ -198,13 +204,22 @@ struct Testbed { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; - using ElementD = cutlass::half_t; using ElementCompute = float; - using ElementSoftmax = cutlass::half_t; + using ElementD = ElementC; + using ElementSoftmax = ElementC; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + static int const kStages = 3; + /// Linear scaling operator using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< ElementC, @@ -218,12 +233,21 @@ struct Testbed { ElementB, LayoutB, ElementC, ElementCompute, - EpilogueFunctorOp + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + kStages >; using ElementNorm = typename GemmSoftmax::ElementNorm; using ElementSum = typename GemmSoftmax::ElementSum; using LayoutC = typename GemmSoftmax::LayoutC; + using LayoutN = typename GemmSoftmax::LayoutN; + using LayoutS = typename GemmSoftmax::LayoutS; + using MatrixCoord = typename LayoutC::TensorCoord; // // Data members @@ -231,20 +255,42 @@ struct Testbed { Options const &options; - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor tensor_N; - cutlass::HostTensor tensor_S; - cutlass::HostTensor tensor_Softmax; - cutlass::HostTensor reference_D; cutlass::HostTensor reference_N; - cutlass::HostTensor reference_Softmax; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_Ref; + cutlass::DeviceAllocation block_Softmax; + cutlass::DeviceAllocation block_Norm; + cutlass::DeviceAllocation block_Sum; int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN; + cutlass::gemm::GemmCoord problem = options.problem_size; + + int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); + int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); + int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + // fixed rowmajor for norm and sum + int64_t ldn = problem.m(); + int64_t lds = ldn; + + int64_t total_elements_A_per_batch = problem.m() * problem.k(); + int64_t total_elements_B_per_batch = problem.k() * problem.n(); + int64_t total_elements_C_per_batch = problem.m() * problem.n(); + int64_t total_elements_D_per_batch = problem.m() * 
problem.n(); + int64_t total_elements_partial_norm_per_batch = block_num * problem.m(); + + int64_t total_elements_A = total_elements_A_per_batch * options.batch_count; + int64_t total_elements_B = total_elements_B_per_batch * options.batch_count; + int64_t total_elements_C = total_elements_C_per_batch * options.batch_count; + int64_t total_elements_D = total_elements_D_per_batch * options.batch_count; + int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count; + // // Methods // @@ -254,20 +300,7 @@ struct Testbed { ): options(options_) { - - tensor_A.reset({options.problem_size.m(), options.problem_size.k()}); - tensor_B.reset({options.problem_size.k(), options.problem_size.n()}); - - tensor_C.reset({options.problem_size.m(), options.problem_size.n()}); - tensor_D.reset({options.problem_size.m(), options.problem_size.n()}); - - tensor_N.reset({block_num, options.problem_size.m()}); - tensor_S.reset({block_num, options.problem_size.m()}); - tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()}); - - reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false); reference_N.reset({options.problem_size.m(), 1}, false); - reference_Softmax.reset({options.problem_size.m(), options.problem_size.n()}, false); } /// Run @@ -300,11 +333,6 @@ struct Testbed { return disposition; } - // - // Compute the reference - // - compute_reference(); - // // Verify // @@ -334,43 +362,38 @@ struct Testbed { /// Random initialization void initialize() { - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), - options.seed, - ElementD(5), - ElementD(-5), - 0 - ); + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_Softmax.reset(total_elements_D); + block_Ref.reset(total_elements_D_per_batch); + block_Norm.reset(total_elements_partial_norm); + block_Sum.reset(total_elements_partial_norm); - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), - options.seed + 19, - ElementD(5), - ElementD(-5), - 0 - ); + cutlass::reference::device::BlockFillRandomUniform( + block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0); - cutlass::reference::host::TensorFill( - reference_D.host_view(), - ElementD() - ); + cutlass::reference::device::BlockFillRandomUniform( + block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0); cutlass::reference::host::TensorFill( reference_N.host_view(), ElementNorm() ); - cutlass::reference::host::TensorFill( - reference_Softmax.host_view(), - ElementSoftmax() - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_D.sync_device(); - tensor_N.sync_device(); - tensor_S.sync_device(); - tensor_Softmax.sync_device(); } cutlass::Status execute_device_kernel() { @@ -384,17 +407,24 @@ struct Testbed { GemmSoftmax::Arguments args( options.problem_size, 
options.batch_count, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), + {block_A.get(), lda}, + {block_B.get(), ldb}, + {block_C.get(), ldc}, + {block_D.get(), ldc}, { ElementCompute(options.alpha), ElementCompute(options.beta) }, - tensor_N.device_ref(), - tensor_S.device_ref(), - tensor_Softmax.device_ref() + {block_Norm.get(), ldn}, + {block_Sum.get(), lds}, + {block_Softmax.get(), ldc}, + total_elements_A_per_batch, + total_elements_B_per_batch, + total_elements_C_per_batch, + total_elements_D_per_batch, + total_elements_partial_norm_per_batch, + total_elements_partial_norm_per_batch, + total_elements_D_per_batch ); // @@ -415,68 +445,21 @@ struct Testbed { return status; } - /// Reference calculation - void compute_reference() { + template + bool verify_tensor(std::vector vector_Input, \ + std::vector vector_Input_Ref) { - // Compute GEMM - - cutlass::reference::host::GemmComplex( - options.problem_size, - options.alpha, - tensor_A.host_ref(), - cutlass::ComplexTransform::kNone, - tensor_B.host_ref(), - cutlass::ComplexTransform::kNone, - options.beta, - tensor_C.host_ref(), - reference_D.host_ref(), - double() - ); - - // Compute the norm - for (int m = 0; m < options.problem_size.m(); ++m) { - reference_N.at({m, 0}) = reference_D.at({m, 0}); - for (int n = 1; n < options.problem_size.n(); ++n) { - reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(reference_D.at({m, n}))); - } - } - - // Compute softmax - for (int m = 0; m < options.problem_size.m(); ++m) { - - float sum = float(); - - for (int n = 0; n < options.problem_size.n(); ++n) { - sum += std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ); - } - - float inv_sum = float(1.0f / sum); - - for (int n = 0; n < options.problem_size.n(); ++n) { - - reference_Softmax.at({m, n}) = ElementSoftmax( - std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum - ); - } - } - } - - /// Emits all tensor values - void emit_results() { - std::cout << "D = \n" << tensor_D.host_view() << "\n\n"; - std::cout << "N = \n" << tensor_N.host_view() << "\n\n"; - std::cout << "Softmax = \n" << tensor_Softmax.host_view() << "\n\n"; - std::cout << "Reference N = \n" << reference_N.host_view() << "\n\n"; - std::cout << "Reference D = \n" << reference_D.host_view() << "\n\n"; - std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n"; - } - - bool verify_tensor_N(cutlass::HostTensor tensor_N, \ - cutlass::HostTensor reference_N) { - - for (int m = 0; m < options.problem_size.m(); ++m) { - float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0})); - if (fabs(diff) > options.tolerance) { + int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); + float abs_tol = options.tolerance; + float rel_tol = options.tolerance; + + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i)); + float relative_diff = abs_ref > abs_tol ? 
abs_diff / abs_ref : 0; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { + printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } @@ -488,80 +471,112 @@ struct Testbed { /// Verifies the reference matches bool verify() { - tensor_D.sync_host(); - tensor_N.sync_host(); - tensor_Softmax.sync_host(); + LayoutA layout_A(lda); + LayoutB layout_B(ldb); + LayoutC layout_C(ldc); + LayoutN Layout_N(ldn); + LayoutS Layout_S(lds); - double const kThreshold = options.tolerance; + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; - // Verification checks - set any of these to 'true' to override the verification checks. - bool verified_D = false; - bool verified_N = false; - bool verified_Softmax = false; + for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) { - // Verify softmax output - if (!verified_D) { + cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_C, extent_C); - double norm_diff = cutlass::reference::host::TensorNormDiff( - tensor_D.host_view(), - reference_D.host_view()); + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementCompute + >( + problem, + options.alpha, + view_A, + cutlass::ComplexTransform::kNone, + view_B, + cutlass::ComplexTransform::kNone, + options.beta, + view_C, + view_Ref_device, + ElementCompute(0) + ); - double norm_reference = cutlass::reference::host::TensorNorm( - reference_D.host_view()); + // Copy reference results to host memory for verification + std::vector matrix_D_Ref(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); + cutlass::TensorView view_Ref(matrix_D_Ref.data(), layout_C, extent_C); - double rel_error = norm_diff / norm_reference; + std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); + cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); - if (rel_error > kThreshold) { - std::cerr << "\n\nTensor D Relative error: " << rel_error << std::endl; + // Copy computed results to host memory + std::vector matrix_D(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); + + std::vector matrix_Softmax(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); + + // Compute the norm + for (int m = 0; m < options.problem_size.m(); ++m) { + reference_N.at({m, 0}) = view_Ref.ref().at({m, 0}); + for (int n = 1; n < options.problem_size.n(); ++n) { + reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n}))); + } } - else { - verified_D = true; + + // Compute softmax + for (int m = 0; m < options.problem_size.m(); ++m) { + + float sum = float(); + + for (int n = 0; n < options.problem_size.n(); ++n) { + sum += std::exp( float(view_Ref.ref().at({m, n})) - 
float(reference_N.at({m, 0})) ); + } + + float inv_sum = float(1.0f / sum); + + for (int n = 0; n < options.problem_size.n(); ++n) { + + view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( + std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum + ); + } } - } - if (!verified_N) { - verified_N = verify_tensor_N(tensor_N, reference_N); - } + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_D = false; + bool verified_Softmax = false; - if (!verified_Softmax) { - - double norm_diff = cutlass::reference::host::TensorNormDiff( - tensor_Softmax.host_view(), - reference_Softmax.host_view()); - - double norm_reference = cutlass::reference::host::TensorNorm( - reference_Softmax.host_view()); - - double rel_error = norm_diff / norm_reference; - - if (rel_error > kThreshold) { - std::cerr << "\n\nSoftmax Relative error: " << rel_error << std::endl; - } - else { - verified_Softmax = true; - } - } - - if (!verified_D || !verified_N || !verified_Softmax) { - - std::cerr << "Verification check failed for tensor Softmax" << std::endl; - - emit_results(); - - // Summarize which checks failed + // Verify softmax output if (!verified_D) { - std::cerr << "Verification of D tensor failed\n"; - } - - if (!verified_N) { - std::cerr << "Verification of N tensor failed\n"; + verified_D = verify_tensor(matrix_D, matrix_D_Ref); } if (!verified_Softmax) { - std::cerr << "Verification of Softmax tensor failed\n"; + verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref); + } + + if (!verified_D || !verified_Softmax) { + + std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; + + // Summarize which checks failed + if (!verified_D) { + std::cerr << "Verification of D tensor failed\n"; + } + + if (!verified_Softmax) { + std::cerr << "Verification of Softmax tensor failed\n"; + } + + return false; } - return false; } return true; @@ -637,14 +652,17 @@ struct Testbed { int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); - double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); - double gbytes_per_second = double(bytes) * kIterations / double(elapsed_ms / 1000.0f) / double(1 << 30); + double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); + double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); + + double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; std::cout << " Problem: " << options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k() + << ", batch size: " << options.batch_count << std::endl; - std::cout << " Runtime: " << elapsed_ms << " ms\n" << std::endl; + std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl; std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; diff --git a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h index 814de5ae..263ed75a 100644 --- a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +++ b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h @@ -29,7 +29,8 @@ * 
**************************************************************************************************/ /*! \file - \brief GEMM kernel to support the 'epilogue visitor' model for fusion. + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once its usage has been stabilized. For now, it is included in this example to demonstrate @@ -78,6 +79,7 @@ public: using ElementC = typename EpilogueVisitor::ElementOutput; using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformB; @@ -89,6 +91,9 @@ public: using InstructionShape = typename Mma::Policy::Operator::InstructionShape; using ArchTag = typename Mma::ArchTag; + using ElementNorm = typename EpilogueVisitor::ElementNorm; + using ElementSum = typename EpilogueVisitor::ElementSum; + static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; @@ -121,6 +126,11 @@ public: TensorRefA ref_A; TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + ElementNorm *ptr_Max; + ElementSum *ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; @@ -144,6 +154,10 @@ public: int batch_count_, TensorRefA ref_A_, TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + ElementNorm *ptr_Max_, + ElementSum *ptr_Sum_, int64_t batch_stride_A_, int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_ @@ -153,6 +167,10 @@ public: batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Max(ptr_Max_), + ptr_Sum(ptr_Sum_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), epilogue_visitor(epilogue_visitor_) @@ -174,6 +192,8 @@ public: typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; GemmUniversalMode mode; int batch_count; @@ -181,6 +201,11 @@ public: void * ptr_A; void * ptr_B; + ElementC * ptr_C; + ElementC * ptr_D; + + ElementNorm * ptr_Max; + ElementSum * ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; @@ -196,11 +221,17 @@ public: swizzle_log_tile(0), params_A(0), params_B(0), + params_C(0), + params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), batch_stride_A(0), batch_stride_B(0) { } @@ -213,11 +244,17 @@ public: swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) @@ -467,7 +504,14 @@ public: thread_idx, warp_idx, lane_idx, - threadblock_offset); + params.params_C, + params.params_D, + params.ptr_C, + params.ptr_D, + params.ptr_Max, + params.ptr_Sum, + threadblock_offset, + blockIdx.y *params.problem_size.m() ); 
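+
+ // (The blockIdx.y * M term offsets into the per-threadblock-column buffers of
+ // partial row maxima and exponential sums; a separate final reduction kernel
+ // later collapses these partials into a single norm and inverse sum per row.)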
if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is currently updating diff --git a/examples/35_gemm_softmax/gemm_with_softmax.h b/examples/35_gemm_softmax/gemm_with_softmax.h index 213f8c5a..37f7a746 100644 --- a/examples/35_gemm_softmax/gemm_with_softmax.h +++ b/examples/35_gemm_softmax/gemm_with_softmax.h @@ -49,10 +49,12 @@ #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_complex.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/reduction/kernel/reduce_softmax_final.h" ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "epilogue_with_visitor.h" #include "gemm_with_epilogue_visitor.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -209,6 +211,9 @@ private: int idx_m = block_m + thread_m; int idx_n = block_n + thread_n; + int batch_offset_norm = block_batch * params.args.batch_stride_N; + int batch_offset_sum = block_batch * params.args.batch_stride_S; + // Kill off thread if it is outside the row boundary if (params.args.extent.row() <= idx_m) { return; @@ -251,8 +256,8 @@ private: params.args.batch_stride_Soft * block_batch + params.args.ref_Soft.layout()({idx_m, idx_n})); - ElementSum inv_sum = (params.args.ref_S.data())[block_m]; - ElementNorm norm = (params.args.ref_N.data())[block_m]; + ElementSum inv_sum = (params.args.ref_S.data())[block_m + batch_offset_sum]; + ElementNorm norm = (params.args.ref_N.data())[block_m + batch_offset_norm]; // // Loop @@ -281,556 +286,6 @@ private: } }; -template < - typename ElementNorm_, - typename ElementSum_, - typename ElementSoftmaxCompute_, - typename ThreadblockShape_ -> -class ApplyFinalReduction { -public: - - using ElementNorm = ElementNorm_; - using ElementSum = ElementSum_; - using ElementSoftmaxCompute = ElementSoftmaxCompute_; - using ThreadblockShape = ThreadblockShape_; - - using Layout = cutlass::layout::RowMajor; - - using TensorRefN = TensorRef; - using TensorRefSum = TensorRef; - - // - // Arguments - // - - struct Arguments { - - MatrixCoord extent; ///< Extent of D and Softmax matrices - int batch_count; ///< Batch count - TensorRefN ref_N; ///< Norm tensor (input / output) - TensorRefSum ref_Sum; ///< Sum tensor (input / output) - int64_t batch_stride_N; ///< Batch stride for N tensor - int64_t batch_stride_Sum; ///< Batch stride for softmax tensor - - // - // Methods - // - Arguments(): - batch_count(1), - batch_stride_N(0), - batch_stride_Sum(0) - { } - - Arguments( - MatrixCoord extent_, ///< Extent of D and Softmax matrices - int batch_count_, ///< Batch count - TensorRefN ref_N_, ///< Output parameter for N - TensorRefSum ref_Sum_ , ///< Sum - int64_t batch_stride_N_ = 0, - int64_t batch_stride_Sum_ = 0 - ): - extent(extent_), - batch_count(batch_count_), - ref_N(ref_N_), - ref_Sum(ref_Sum_), - batch_stride_N(batch_stride_N_), - batch_stride_Sum(batch_stride_Sum_) - { - - } - }; - - struct SharedStorage { - - - }; - - // - // Params struct - // - - struct Params { - Arguments args; - - // - // Methods - // - Params() { } - - Params(Arguments const &args_): args(args_) { } - }; - -private: - -public: - - CUTLASS_DEVICE - ApplyFinalReduction() { } - - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - 
apply(params, shared_storage); - } - -private: - - /// Partial reduction - CUTLASS_DEVICE - void apply(Params const ¶ms, SharedStorage &shared_storage) { - - int threadblock_num = (params.args.extent.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - int block_batch = blockIdx.z; - - int block_n = blockIdx.x * blockDim.x; - - int thread_n = threadIdx.x; - - int idx_n = block_n + thread_n; - - if (idx_n >= params.args.extent.row()) { - return; - } - - - using ConvertSumOutput = cutlass::NumericConverter; - using ConvertNormOutput = cutlass::NumericConverter; - - using ConvertSum = cutlass::NumericConverter; - using ConvertNorm = cutlass::NumericConverter; - - ConvertSum convert_sum; - ConvertNorm convert_norm; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - ElementNorm *access_n = params.args.ref_N.data() + params.args.batch_stride_N * block_batch + idx_n; - ElementSum *access_s = params.args.ref_Sum.data() + params.args.batch_stride_Sum * block_batch + idx_n; - - ElementNorm *access_n_bak = access_n; - ElementSum *access_s_bak = access_s; - - uint32_t float_max_bits = 0xff7fffff; - float min_float = reinterpret_cast(float_max_bits); - - ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); - ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); - ElementNorm fetch_n; - ElementSum fetch_s; - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { - arch::global_load(fetch_n, access_n, true); - max_val = fast_max(max_val, convert_norm(fetch_n)); - access_n += params.args.extent.row(); - } - - access_n = access_n_bak; - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { - arch::global_load(fetch_n, access_n, true); - arch::global_load(fetch_s, access_s, true); - sum_val += convert_sum(fetch_s) * fast_exp(convert_norm(fetch_n) - max_val); - access_n += params.args.extent.row(); - access_s += params.args.extent.row(); - } - - ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; - - access_n = access_n_bak; - access_s = access_s_bak; - - access_n[0] = convert_norm_output(max_val); - access_s[0] = convert_sum_output(inv_sum); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ThreadblockShape_, - int ThreadCount, - typename OutputTileIterator_, - typename ElementAccumulator_, - typename ElementNorm_, - typename ElementSum_, - typename ElementSoftmaxCompute_, - typename ElementwiseFunctor_ -> -class EpilogueVisitorBiasMax { -public: - - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using ElementNorm = ElementNorm_; - using ElementSum = ElementSum_; - using ElementSoftmaxCompute = ElementSoftmaxCompute_; - - using AccumulatorFragment = Array; - using SoftmaxFragment = Array; - using OutputVector = Array; - using TensorRefD = TensorRef; - - /// Argument structure - struct Arguments { - - typename ElementwiseFunctor::Params elementwise; - TensorRefD ref_C; - TensorRefD ref_D; - ElementNorm *ptr_Max; - ElementSum *ptr_Sum; - 
int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Max; - int64_t batch_stride_Sum; - - // - // Methods - // - Arguments(): - ptr_Max(nullptr), - ptr_Sum(nullptr), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_Max(0), - batch_stride_Sum(0) - { - - } - - Arguments( - typename ElementwiseFunctor::Params elementwise_, - TensorRefD ref_C_, - TensorRefD ref_D_, - ElementNorm *ptr_Max_, - ElementSum *ptr_Sum_, - int64_t batch_stride_C_, - int64_t batch_stride_D_, - int64_t batch_stride_Max_, - int64_t batch_stride_Sum_ - ): - elementwise(elementwise_), - ref_C(ref_C_), - ref_D(ref_D_), - ptr_Max(ptr_Max_), - ptr_Sum(ptr_Sum_), - batch_stride_C(batch_stride_C_), - batch_stride_D(batch_stride_D_), - batch_stride_Max(batch_stride_Max_), - batch_stride_Sum(batch_stride_Sum_) - { - - } - }; - - struct Params { - - typename ElementwiseFunctor::Params elementwise; - typename OutputTileIterator::Params params_C; - typename OutputTileIterator::Params params_D; - typename OutputTileIterator::Element *ptr_C; - typename OutputTileIterator::Element *ptr_D; - ElementNorm *ptr_Max; - ElementSum *ptr_Sum; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Max; - int64_t batch_stride_Sum; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(): - ptr_D(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr) - { - - } - - CUTLASS_HOST_DEVICE - Params(Arguments const &args): - elementwise(args.elementwise), - params_C(args.ref_C.layout()), - params_D(args.ref_D.layout()), - ptr_C(args.ref_C.data()), - ptr_D(args.ref_D.data()), - ptr_Max(args.ptr_Max), - ptr_Sum(args.ptr_Sum), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - batch_stride_Max(args.batch_stride_Max), - batch_stride_Sum(args.batch_stride_Sum) - { - - } - }; - - /// Shared storage - struct SharedStorage { - - }; - -private: - - Params const & params_; - SharedStorage & shared_storage_; - MatrixCoord extent_; - ElementwiseFunctor elementwise_; - - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator alpha_; - ElementAccumulator beta_; - - ElementSoftmaxCompute accum_max_; - int threadblock_row_; - -public: - - CUTLASS_DEVICE - EpilogueVisitorBiasMax( - Params const ¶ms, ///< Parameters routed to the epilogue - SharedStorage &shared_storage, ///< Shared storage needed by the functors here - MatrixCoord const &problem_size, ///< Problem size of the output - int thread_idx, ///< Thread index within the threadblock - int warp_idx, ///< Warp index within the threadblock - int lane_idx, ///< Lane index within the warp - MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) - ): - params_(params), - shared_storage_(shared_storage), - extent_(problem_size), - elementwise_(params.elementwise), - iterator_C_(params.params_C, params.ptr_C, problem_size, thread_idx, threadblock_offset), - iterator_D_(params.params_D, params.ptr_D, problem_size, thread_idx, threadblock_offset), - threadblock_row_(threadblock_offset.row()) - { - alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) { - iterator_C_.clear_mask(); - } - } - - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition( - int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) { ///< Total number of split-K slices - - } - - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) { - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); - } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() { - - } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } - - } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) { - - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit( - int row_idx, - int column_idx, - int frag_idx, - AccumulatorFragment const &accum) { - - using Mul = cutlass::multiplies; - using Minus = cutlass::minus; - using Exp = cutlass::fast_exp_op; - - Minus minus; - Exp exponential; - - SoftmaxFragment result; - - using ConvertSumOutput = cutlass::NumericConverter; - using ConvertNormOutput = cutlass::NumericConverter; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - NumericArrayConverter source_converter; - OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; - - if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { - result = source_converter(elementwise_(accum)); - }else{ - result = source_converter(elementwise_(accum, source_vector)); - } - - MatrixCoord thread_offset = - iterator_D_.thread_start() + - OutputTileIterator::ThreadMap::iteration_offset(frag_idx); - - int thread_in_row = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; - int half_thread_in_row = (thread_in_row >> 1); - - bool column_guard = (thread_offset.column() < extent_.column()); - - // Compute the maximum within one row - if (!column_idx) { - // This is the first fragment in a new row - if (column_guard) { - accum_max_ = maximum_accumulator_(result); - } - } - else { - // This is an additional fragment in the same row - if (column_guard) { - accum_max_ = maximum_accumulator_(result, accum_max_); - } - } - - CUTLASS_PRAGMA_UNROLL - for (int i = half_thread_in_row; i > 0; i >>= 1) { - ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, accum_max_, i); - accum_max_ = fast_max(accum_max_, tmp); - } - - SoftmaxFragment sum_frag = exponential(minus(result, accum_max_)); - - ElementSoftmaxCompute reduction_sum = sum_accumulator_(sum_frag); - - CUTLASS_PRAGMA_UNROLL - for (int i = half_thread_in_row; i > 0; i >>= 1) { - ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, reduction_sum, i); - reduction_sum += tmp; - } - - bool is_write_thread = (thread_offset.row() < extent_.row() && (threadIdx.x % thread_in_row) == 0); - ElementNorm *curr_ptr_max = params_.ptr_Max + thread_offset.row() + blockIdx.y * extent_.row(); - ElementSum *curr_ptr_sum = params_.ptr_Sum + thread_offset.row() + blockIdx.y * 
extent_.row(); - - arch::global_store( - convert_norm_output(accum_max_), - (void *)curr_ptr_max, - is_write_thread); - - arch::global_store( - convert_sum_output(reduction_sum), - (void *)curr_ptr_sum, - is_write_thread); - - clear_accum_max_(); - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the start of a row - CUTLASS_DEVICE - void end_row(int row_idx) { - - } - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() { - - } - -private: - - CUTLASS_DEVICE - void clear_accum_max_() { - - uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX - float min_float = reinterpret_cast(float_max_bits); - accum_max_ = ElementSoftmaxCompute(min_float); - } - - CUTLASS_DEVICE - ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) { - ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SoftmaxFragment::kElements; ++i) { - sum_ += ElementSoftmaxCompute(accum[i]); - } - - return sum_; - } - - CUTLASS_DEVICE - ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) { - ElementSoftmaxCompute max_ = accum[0]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < SoftmaxFragment::kElements; ++i) { - max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); - } - - return max_; - } - - CUTLASS_DEVICE - ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SoftmaxFragment::kElements; ++i) { - max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); - } - - return max_; - } -}; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -846,10 +301,18 @@ template < typename LayoutB_, typename ElementC_, typename ElementCompute_, + typename OperatorClass_, + typename ArchTag_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, typename EpilogueFunctorOp_, + int kStages_, + int AlignmentA_ = 128 / cutlass::sizeof_bits::value, + int AlignmentB_ = 128 / cutlass::sizeof_bits::value, + int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, typename ElementNorm_ = float, typename ElementSum_ = float, - int Alignment = 128 / cutlass::sizeof_bits::value, typename ElementSoftmax_ = ElementC_ > class GemmSoftmax { @@ -872,8 +335,6 @@ public: using LayoutA = LayoutA_; using LayoutB = LayoutB_; - static int const kAlignment = Alignment; - using EpilogueFunctorOp = EpilogueFunctorOp_; using ElementNorm = ElementNorm_; @@ -890,13 +351,17 @@ public: using TensorRefSum = TensorRef; using TensorRefSoft = TensorRef; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using ArchTag = cutlass::arch::Sm80; - static int const kStages = 3; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + + static int const kStages = kStages_; + static int const AlignmentA = AlignmentA_; + static int const AlignmentB = 
AlignmentB_; + static int const AlignmentSoftmax = AlignmentSoftmax_; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; @@ -906,10 +371,10 @@ public: using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, - kAlignment, + AlignmentA, ElementB, LayoutB, - kAlignment, + AlignmentB, ElementC, LayoutC, ElementCompute, @@ -930,7 +395,7 @@ public: /////////////////////////////////////////////////////////////////////////////////////////////// // Epilogue visitor - using EpilogueVisitor = kernel::EpilogueVisitorBiasMax< + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< ThreadblockShape, DefaultGemmKernel::kThreadCount, typename DefaultGemmKernel::Epilogue::OutputTileIterator, @@ -961,13 +426,13 @@ public: ElementSum, ElementSoft, ElementSoftmaxCompute, - kAlignment, + AlignmentSoftmax, MatrixShape< 1, 1024 > >; - using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< ElementNorm, ElementSum, ElementSoftmaxCompute, @@ -983,6 +448,7 @@ public: typename SoftmaxApplyKernel::Arguments softmax; typename ApplyFinalReductionKernel::Arguments reduction; cutlass::gemm::GemmCoord extend; + // // Methods // @@ -1013,14 +479,14 @@ public: batch_count_, ref_A_, ref_B_, + ref_C_, + ref_D_, + ref_N_.data(), + ref_S_.data(), batch_stride_A_, batch_stride_B_, typename EpilogueVisitor::Arguments( linear_scaling, - ref_C_, - ref_D_, - ref_N_.data(), - ref_S_.data(), batch_stride_C_, batch_stride_D_, batch_stride_Max_, @@ -1028,10 +494,9 @@ public: ) ), reduction( - MatrixCoord(problem_size.m(), problem_size.n()), - batch_count_, - ref_N_, - ref_S_, + problem_size, + ref_N_.data(), + ref_S_.data(), batch_stride_Max_, batch_stride_Sum_ ), @@ -1127,28 +592,24 @@ public: // Launch the ApplyFinalReductionKernel // - int threadblock_num_in_column = (params_.extend.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + int thread_per_block = 128; + int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + } - if (threadblock_num_in_column > 1) { - int thread_per_block = 128; - int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; - if (block_per_row < 4) { - thread_per_block = 32; - block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; - } + dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); + dim3 final_reduction_block(thread_per_block); - dim3 final_reduction_grid(block_per_row); - dim3 final_reduction_block(thread_per_block); + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); - Kernel<<< - final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream - >>>(params_.reduction); + result = cudaGetLastError(); - result = cudaGetLastError(); - - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; } // diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index f8fbcc33..afab8344 100644 --- 
a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -40,18 +40,17 @@ // for (int j = 0; j < options.index_size; ++j) { // int b_c_d_col = tensor_indices.at({j, 0}); // -// for (int k = 0; k < problem_size.k(); ++k) { +// for (int k = 0; k < options.index_size; ++k) { // tensor_d_ref.at({i, b_c_d_col}) += // alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); // } // } -// } // // Note that the index vector contains unique random integers with max to be N - 1 // // The gather/scatter operation works best when we can still keep the biggest // alignment. For example, when the matrix is row major, we select rows. When -// the matrix is column major, we selct columns. +// the matrix is column major, we select columns. // // Not all the combination of gather and scatter are legal. For example, if A is // row major and C/D is column major, we cannot gather A and scatter C/D at the @@ -257,7 +256,7 @@ using Gemm = cutlass::gemm::device::GemmUniversal; @@ -353,7 +352,7 @@ int run(Options &options) { tensor_b.layout().stride(), tensor_c.layout().stride(), tensor_d_scattered.layout().stride(), - nullptr, // <- pointer to index vector to gather A on device + nullptr, // <- pointer to index vector to gather A on device tensor_indices.device_data(), // <- pointer to index vector to gather B on device tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device @@ -392,7 +391,7 @@ int run(Options &options) { tensor_d_ref.at({i, b_c_d_col}) += alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); } - + tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col})); } } @@ -515,7 +514,7 @@ int main(int argc, const char ** argv) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt b/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt new file mode 100644 index 00000000..af50bb11 --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 37_gemm_layernorm_gemm_fusion + gemm_layernorm.cu + ) + diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu b/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu new file mode 100644 index 00000000..ca33d44e --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu @@ -0,0 +1,937 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Layernorm Example. + + This workload provides a layer normalization example using a one-pass, square-sum-based + variance calculation. Specifically, we fuse the reduction operation to find + local mean and local square sum mean in the epilogue of 1st GEMM. After a light + full reduction kernel, the mean / variance values are readily calculated for element-wise + operations which are fused into the 2nd GEMM. + + As stated in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data, + the square-sum based one-pass implementation may raise concerns on numerical stability issues. 
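+
+    (For reference, the identity behind the one-pass scheme is Var[X] = E[X^2] - (E[X])^2; the shifted
+    form Var[X] = E[(X - K)^2] - (E[X - K])^2, with the constant K chosen close to the mean, is what the
+    kIsShiftedVariance / ptr_Shifted_K path in gemm_with_layernorm.h computes, and it is less prone to
+    catastrophic cancellation.)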
+
+    That being said, though this fully fused layernorm example almost perfectly hides all the memory cost to
+    access the intermediate matrix for layernorm computation, the numerical-stability issue may limit its use
+    in real-world scenarios. If that is the case, a user may turn to the stand-alone CUTLASS layernorm
+    example in tools/util/include/cutlass/util/device_layernorm.h
+
+    Examples:
+
+      # Run a CUTLASS layernorm example with the default setup,
+      # using the language of the transformer model as an example,
+      (Column Major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4)
+      $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion
+
+      # Run a layernorm example with hidden dimension = 512
+      $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512
+
+*/
+
+#include <cmath>
+#include <iostream>
+#include <vector>
+#include <limits>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/arch/memory.h"
+#include "cutlass/arch/memory_sm75.h"
+#include "cutlass/gemm/device/gemm_complex.h"
+#include "cutlass/epilogue/thread/scale_type.h"
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/reference/device/gemm.h"
+#include "cutlass/util/reference/host/gemm_complex.h"
+#include "cutlass/util/reference/host/tensor_reduce.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_norm.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+#include "cutlass/util/reference/host/error_metrics.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/fast_math.h"
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "gemm_with_layernorm.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+enum class Disposition {
+  kPassed,
+  kIncorrect,
+  kNotVerified
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Command line options parsing
+template <typename LayoutOutput_>
+struct Options {
+
+  using LayoutOutput = LayoutOutput_;
+
+  static bool const kIsColumnMajorOutput = cutlass::platform::is_same<LayoutOutput, cutlass::layout::ColumnMajor>::value;
+
+  bool help;
+  cutlass::gemm::GemmCoord problem_size0;
+  cutlass::gemm::GemmCoord problem_size1;
+  int hidden_dim;
+  int valid_word_num;
+  int intermediate_scale;
+  int iterations;
+  unsigned seed;
+  float alpha;
+  float beta;
+  bool verification_enabled;
+  double tolerance;
+
+  Options():
+    help(false),
+    iterations(20),
+    seed(2022),
+    hidden_dim(768),
+    valid_word_num(4096),
+    intermediate_scale(4),
+    alpha(1),
+    beta(0),
+    verification_enabled(true),
+    tolerance(0.01),
+    problem_size1(problem_size0.m() * 4, problem_size0.n(), problem_size0.m())
+  { }
+
+  bool valid() {
+
+    return true;
+  }
+
+  // Parses the command line
+  void parse(int argc, char const **args) {
+    cutlass::CommandLine cmd(argc, args);
+
+    if (cmd.check_cmd_line_flag("help")) {
+      help = true;
+    }
+
+    cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768);
+    cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096);
+    cmd.get_cmd_line_argument("iterations", iterations);
+    cmd.get_cmd_line_argument("verify", verification_enabled);
+    cmd.get_cmd_line_argument("seed", seed);
+    cmd.get_cmd_line_argument("tolerance", tolerance);
+
+    if (kIsColumnMajorOutput) {
+      // column major output setup
+      problem_size0.m() = hidden_dim;
+      problem_size0.n() = valid_word_num;
+      
problem_size0.k() = hidden_dim; + + problem_size1.m() = hidden_dim * intermediate_scale; + problem_size1.n() = valid_word_num; + problem_size1.k() = hidden_dim; + }else{ + // row major output setup + problem_size0.m() = valid_word_num; + problem_size0.n() = hidden_dim; + problem_size0.k() = hidden_dim; + + problem_size1.m() = valid_word_num; + problem_size1.n() = hidden_dim * intermediate_scale; + problem_size1.k() = hidden_dim; + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "37_gemm_layernorm_gemm_fusion example\n\n" + << " This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --hidden_dim= Hidden dimension\n" + << " --valid_word_num= Valid word number\n" + << " --seed= Random number seed (1*)\n\n" + << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" + << " --verify= If true, performs reference calculation.\n\n" + << " --tolerance Error tolerance\n" + ; + + out << "\n\nExamples:\n\n" + << "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n" + << " --hidden_dim=768 --valid_word_num=1024 \n\n"; + + return out; + } + + /// Returns true if the environment and Toolkit support this + bool supported(bool verbose = true) const { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + if (verbose) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + } + return false; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + if (verbose) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + } + return false; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + if (verbose) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + } + return false; + } + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. + // + int const kAlignment = 8; + + if ((problem_size0.m() % kAlignment) || + (problem_size0.n() % kAlignment) || + (problem_size0.k() % kAlignment)) { + if (verbose) { + std::cerr << "Misaligned input in 1st GEMM." << std::endl; + } + // misaligned tensors for Gemm1 + return false; + } + + if ((problem_size1.m() % kAlignment) || + (problem_size1.n() % kAlignment) || + (problem_size1.k() % kAlignment)) { + if (verbose) { + std::cerr << "Misaligned input in 2nd GEMM." 
<< std::endl; + } + // misaligned tensors for Gemm2 + return false; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + typename LayoutOutput_> +struct Testbed { + + // + // Type definitions + // + + // User-defined data types + using ElementInputA0 = cutlass::half_t; + using ElementInputB0 = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + using LayoutInputA0 = cutlass::layout::RowMajor; + using LayoutInputB0 = cutlass::layout::ColumnMajor; + using LayoutOutput = LayoutOutput_; + + static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; + // turn of shifted K by default + static bool const kIsShiftedVariance = false; + + /// Linear scaling operator + using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementCompute, + ElementCompute + >; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kStages0 = 3; + static int const kStages1 = 4; + + using GemmLayernorm = cutlass::GemmLayernorm< + ElementInputA0, + LayoutInputA0, + ElementInputB0, + LayoutInputB0, + ElementOutput, + LayoutOutput, + ElementCompute, + EpilogueFunctorOp, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages0, + kStages1, + kIsShiftedVariance + >; + + using ElementInputA1 = typename GemmLayernorm::ElementInputA1; + using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1; + using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias; + using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute; + + using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1; + using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0; + using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1; + using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias; + + // + // Data members + // + + Options const &options; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_A1; + cutlass::HostTensor tensor_C1; + + cutlass::HostTensor reference_C0; + cutlass::HostTensor reference_C1; + + cutlass::HostTensor tensor_Variance; + cutlass::HostTensor tensor_Mean; + cutlass::HostTensor tensor_Beta; + cutlass::HostTensor tensor_Gamma; + + cutlass::HostTensor reference_Mean; + cutlass::HostTensor reference_Variance; + + // shifted K tensor to better ensure the numerical stability + // According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + // the closer shifted K to the actual mean, the better numerical stability we'll observe + cutlass::HostTensor tensor_Shifted_K; + + // + // Methods + // + + Testbed( + Options const &options_ + ): + options(options_) + { + + tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()}); + tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()}); + + tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); + + tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()}); + tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); + + reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); + reference_C1.reset({options.problem_size1.m(), 
options.problem_size1.n()}); + + int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m(); + int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n(); + + int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM; + + tensor_Variance.reset({block_num, leading_dim_0}); + tensor_Mean.reset({block_num, leading_dim_0}); + tensor_Shifted_K.reset({1, leading_dim_0}); + + tensor_Beta.reset({1, leading_dim_1}); + tensor_Gamma.reset({1, leading_dim_1}); + + reference_Mean.reset({1, leading_dim_0}, false); + reference_Variance.reset({1, leading_dim_0}, false); + + } + + /// Run + Disposition run() { + + Disposition disposition = Disposition::kNotVerified; + + // + // Initialize the workspace + // + + initialize(); + + // + // Launch device kernel + // + cutlass::Status status = cutlass::Status::kSuccess; + + status = execute_device_kernel(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Device execution failed." << std::endl; + return disposition; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Device synchronize failed with error " + << cudaGetErrorString(result) << std::endl; + return disposition; + } + + // + // Compute the reference + // + compute_reference(); + + // + // Verify + // + + if (options.verification_enabled) { + + bool passed = verify(); + + if (passed) { + disposition = Disposition::kPassed; + } + else { + disposition = Disposition::kIncorrect; + } + } + + // + // Profiling + // + if (options.iterations) { + profile(); + } + + return disposition; + } + + /// Random initialization + void initialize() { + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A0.host_view(), + options.seed, + ElementInputA0(5), + ElementInputA0(-5), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B0.host_view(), + options.seed + 1, + ElementInputB0(5), + ElementInputB0(-5), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A1.host_view(), + options.seed + 2, + ElementInputA1(5), + ElementInputA1(-5), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Beta.host_view(), + options.seed + 3, + ElementInputScaleBias(5), + ElementInputScaleBias(-5), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Gamma.host_view(), + options.seed + 4, + ElementInputScaleBias(5), + ElementInputScaleBias(-5), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Shifted_K.host_view(), + options.seed + 5, + ElementOutput(5), + ElementOutput(-6), + 0 + ); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_A1.sync_device(); + tensor_Beta.sync_device(); + tensor_Gamma.sync_device(); + + } + + + + cutlass::Status execute_device_kernel() { + + cutlass::Status status = cutlass::Status::kSuccess; + + // + // Setup arguments + // + + typename GemmLayernorm::Arguments args( + options.problem_size0, + options.problem_size1, + tensor_A0.device_ref().data(), + tensor_B0.device_ref().data(), + tensor_C0.device_ref().data(), + tensor_C0.device_ref().data(), + tensor_A1.device_ref().data(), + tensor_C1.device_ref().data(), + tensor_A0.device_ref().stride(0), + tensor_B0.device_ref().stride(0), + tensor_C0.device_ref().stride(0), + tensor_C0.device_ref().stride(0), + tensor_A1.device_ref().stride(0), + tensor_C1.device_ref().stride(0), + { + ElementCompute(options.alpha), + ElementCompute(options.beta) + }, + 
tensor_Variance.device_ref(), + tensor_Mean.device_ref(), + tensor_Gamma.device_ref(), + tensor_Beta.device_ref(), + tensor_Shifted_K.device_ref().data() + ); + + // + // Launch + // + + GemmLayernorm gemm_layernorm; + + // Initialize + status = gemm_layernorm.initialize(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run + status = gemm_layernorm(); + + return status; + } + + /// Reference calculation + void compute_reference() { + + cutlass::reference::device::Gemm< + ElementInputA0, + LayoutInputA0, + ElementInputB0, + LayoutInputB0, + ElementOutput, + LayoutOutputC0, + ElementCompute, + ElementCompute + > gemm_device0; + + cutlass::reference::device::Gemm< + ElementInputA1, + LayoutInputA1, + ElementOutput, + LayoutOutputC0, + ElementOutputC1, + LayoutOutputC1, + ElementCompute, + ElementCompute + > gemm_device1; + + // Compute 1st GEMM + gemm_device0( + options.problem_size0, + ElementCompute(options.alpha), + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ElementCompute(options.beta), + tensor_C0.device_ref(), + reference_C0.device_ref() + ); + + reference_C0.sync_host(); + + tensor_Mean.sync_host(); + tensor_Variance.sync_host(); + tensor_Gamma.sync_host(); + tensor_Beta.sync_host(); + tensor_Shifted_K.sync_host(); + + // Compute the sum and square sum for verification purpose + if (kIsColumnMajorOutput) { + for (int n = 0; n < options.problem_size0.n(); ++n) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + ElementLayernormCompute square_sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += ElementLayernormCompute(reference_C0.at({m, n})); + square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})); + } + + ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m()); + ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m()); + ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6) ) ; + + mean = -mean * variance; + + reference_Mean.at({0, n}) = ElementInputScaleBias(mean); + reference_Variance.at({0, n}) = ElementInputScaleBias(variance); + } + }else{ + for (int m = 0; m < options.problem_size0.m(); ++m) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + ElementLayernormCompute square_sum = ElementLayernormCompute(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += ElementLayernormCompute(reference_C0.at({m, n})) ; + square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})) ; + } + + ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n()); + ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n()); + ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)) ; + + mean = -mean * variance; + + reference_Mean.at({0, m}) = ElementInputScaleBias(mean); + reference_Variance.at({0, m}) = ElementInputScaleBias(variance); + } + } + + // Element-wise transform for OutputC0 using 1-pass layernorm algo + if (kIsColumnMajorOutput) { + for (int n = 0; n < options.problem_size0.n(); ++n) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += 
ElementLayernormCompute(reference_C0.at({m, n})) ; + } + + ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m())); + sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) ; + } + + ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m()); + ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one() + / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6))) ; + + for (int m = 0; m < options.problem_size0.m(); ++m) { + reference_C0.at({m, n}) = + ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance ) + * tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m})); + + } + + } + }else{ + + for (int m = 0; m < options.problem_size0.m(); ++m) { + + float sum = float(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += float(reference_C0.at({m, n})) ; + } + + float mean = sum / float(options.problem_size0.n()); + sum = float(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean) ; + } + + float square_mean = sum / float(options.problem_size0.n()); + float variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)) ; + + for (int n = 0; n < options.problem_size0.n(); ++n) { + reference_C0.at({m, n}) = + ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance ) + * float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n}))); + + } + + } + + } + + + // Sync host data with device after element-wise transform + reference_C0.sync_device(); + + // Compute 2nd GEMM + gemm_device1( + options.problem_size1, + ElementCompute(options.alpha), + kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(), + kIsColumnMajorOutput ? 
reference_C0.device_ref() :tensor_A1.device_ref(), + ElementCompute(options.beta), + reference_C1.device_ref(), + reference_C1.device_ref() + ); + + } + + /// Emits all tensor values + void emit_results() { + std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n"; + std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n"; + std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n"; + std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n"; + std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n"; + std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n"; + } + + template + bool verify_tensor(cutlass::HostTensor tensor, \ + cutlass::HostTensor reference, + int leading_dim0, int leading_dim1, bool is_print = false) { + float const kThreshold = float(options.tolerance); + float const kAbsThreshold = 0.5f; + float const kRelativeThreshold = 0.1f; + // Adds a constant bias to avoid being divided by '0' + float const kBias = 1e-5f; + int counter = 0; + for (int m = 0; m < leading_dim0; m++) { + for (int n = 0; n < leading_dim1; ++n) { + float diff = (float)(tensor.at({m, n}) - reference.at({m, n})); + float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias); + if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) { + counter++; + } + } + } + + float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1)); + return (err_rate < kThreshold); + } + + /// Verifies the reference matches + bool verify() { + + tensor_Variance.sync_host(); + tensor_Mean.sync_host(); + tensor_C1.sync_host(); + reference_C1.sync_host(); + + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_C1 = false; + bool verified_Mean = false; + bool verified_Variance = false; + + // Verify layernorm output + if (!verified_C1) { + verified_C1 = verify_tensor(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n()); + } + + if (!verified_Variance) { + verified_Variance = verify_tensor(tensor_Variance, reference_Variance, 1, options.problem_size0.n()); + } + + if (!verified_Mean) { + verified_Mean = verify_tensor(tensor_Mean, reference_Mean, 1, options.problem_size0.n()); + } + + if (!verified_C1 || !verified_Mean || !verified_Variance) { + + // emit_results(); + + std::cerr << "Verification check failed for tensor Layernorm" << std::endl; + + // Summarize which checks failed + if (!verified_C1) { + std::cerr << "Verification of O tensor failed\n"; + } + + if (!verified_Mean) { + std::cerr << "Verification of Mean tensor failed\n"; + } + + if (!verified_Variance) { + std::cerr << "Verification of Variance tensor failed\n"; + } + + return false; + } + + return true; + } + + /// Profiles + bool profile() { + + // + // Profile + // + + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t result; + cudaEvent_t events[2]; + int const kIterations = options.iterations; + + for (cudaEvent_t &evt : events) { + result = cudaEventCreate(&evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (int iter = 0; iter < kIterations; ++iter) { + + status = execute_device_kernel(); + + if (status != cutlass::Status::kSuccess) { + 
std::cerr << "Device execution failed." << std::endl; + return false; + } + } + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + float elapsed_ms = 0; + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + float elapsed_ms_per_iter = elapsed_ms / float(kIterations); + + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (cudaEvent_t &evt : events) { + result = cudaEventDestroy(evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2 \ + + int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2; + + double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); + + std::cout << " 1st GEMM: " + << options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n" + << " 2nd GEMM: " + << options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k() + << std::endl; + + std::cout << " Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl; + std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, const char **argv) { + + // Define final layout + using LayoutOutput = cutlass::layout::ColumnMajor; + + // Options parsing + Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (!options.supported()) { + return 0; + } + + // Run + Testbed testbed(options); + + Disposition disposition = testbed.run(); + + std::cout << std::endl; + + switch (disposition) { + case Disposition::kPassed: + std::cout << "Passed" << std::endl; + break; + case Disposition::kIncorrect: + std::cout << "Incorrect" << std::endl; + break; + case Disposition::kNotVerified: + std::cout << "Not verified" << std::endl; + break; + } + + return (disposition == Disposition::kPassed ? 0 : -1); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h new file mode 100644 index 00000000..0323139c --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized layernorm partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + + TensorRefA ref_A; + TensorRefB ref_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + GemmUniversalMode mode; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + mode(args.mode), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); + + if 
(args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(args.problem_size.k(), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = 
static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + threadblock_offset); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h new file mode 100644 index 00000000..654ca40f --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -0,0 +1,1066 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This file contains all functioning classes needed by GemmLayernorm. + + GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + + lightweight full reduction kernel (ApplyFinalReduction) + + GEMM1 with elementwise operations fused in mainloop (GemmLayernormMainloopFusion) + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" +#include "helper.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementOutput, + typename ThreadblockShape_, + bool IsShiftedVariance_ = false +> +class ApplyFinalReduction { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + using ThreadblockShape =
ThreadblockShape_; + + // Pre-processing has ensured a layout equivalent to RowMajor + using Layout = cutlass::layout::RowMajor; + + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Layernorm matrices + TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) + TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) + ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer + + // + // Methods + // + Arguments(){ } + + Arguments( + MatrixCoord extent_, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + ElementOutput *ptr_Shifted_K_ + ): + extent(extent_), + ref_Variance(ref_Variance_), + ref_Mean(ref_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplyFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const &params, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Final reduction over the partial sums produced by the GEMM0 epilogue + CUTLASS_DEVICE + void apply(Params const &params, SharedStorage &shared_storage) { + + int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + + int block_n = blockIdx.x * blockDim.x; + + int thread_n = threadIdx.x; + + int idx_n = block_n + thread_n; + + if (idx_n >= params.args.extent.row()) { + return; + } + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + using ConvertVariance = cutlass::NumericConverter; + using ConvertMean = cutlass::NumericConverter; + + using ConvertShiftK = cutlass::NumericConverter; + + ConvertVariance convert_variance; + ConvertMean convert_mean; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; + ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; + + ElementVariance *access_square_bak = access_square; + ElementMean *access_mean_bak = access_mean; + + ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); + ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); + ElementVariance fetch_square; + ElementMean fetch_mean; + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { + arch::global_load(fetch_square, access_square, true); + arch::global_load(fetch_mean, access_mean, true); + frag_element_sum += convert_mean(fetch_mean); + frag_square_sum += convert_variance(fetch_square); + access_square += params.args.extent.row(); + access_mean += params.args.extent.row(); + } + + ElementLayernormCompute mean = frag_element_sum; + ElementLayernormCompute square_mean = frag_square_sum; + + ElementLayernormCompute variance; + + if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { + ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; + ElementOutput fetch_shift_k; + ConvertShiftK convert_shift_k; + arch::global_load(fetch_shift_k, access_shift_k, true); + ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean +
ElementLayernormCompute(1e-6)); + }else{ + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); + } + + mean = -mean * variance; + + access_square = access_square_bak; + access_mean = access_mean_bak; + + access_square[0] = convert_variance_output(variance); + access_mean[0] = convert_mean_output(mean); + + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename AccumulatorTile_, + typename ElementAccumulator_, + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementwiseFunctor_, + bool IsShiftedVariance_ = false +> +class EpilogueVisitorLayerNorm { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + + using AccumulatorTile = AccumulatorTile_; + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; + + static int const kThreads = OutputTileIterator::ThreadMap::kThreads; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + using ElementOutput = typename OutputTileIterator::Element; + + static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; + + /// Array type used in Shift-K Layernorm + static int const kRowAccessCount = kIterations * kRowIterations; + + using ConvertedShiftFragment = Array; + + // Conducts manual transpose externally (already supported) for column major + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementAccumulator = ElementAccumulator_; + + using AccumulatorFragment = Array; + using LayernormFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; + static int const kThreadsInColumn = kThreads / kThreadsPerRow; + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + TensorRefD ref_C; + TensorRefD ref_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + Arguments(): + ptr_Variance(nullptr), + ptr_Mean(nullptr), + ptr_Shifted_K(nullptr) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + TensorRefD ref_C_, + TensorRefD ref_D_, + ElementVariance *ptr_Variance, + ElementMean *ptr_Mean_, + ElementOutput *ptr_Shifted_K_ = nullptr + ): + elementwise(elementwise_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Variance(ptr_Variance), + ptr_Mean(ptr_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + typename OutputTileIterator::Params params_C; + typename OutputTileIterator::Params params_D; + typename OutputTileIterator::Element *ptr_C; + typename OutputTileIterator::Element *ptr_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // 
Methods + // + CUTLASS_HOST_DEVICE + Params(): + ptr_D(nullptr), + ptr_Variance(nullptr), + ptr_Mean(nullptr) + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Variance(args.ptr_Variance), + ptr_Mean(args.ptr_Mean), + ptr_Shifted_K(args.ptr_Shifted_K) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + ConvertedShiftFragment shift_k_frag_; + + ElementLayernormCompute accum_sum_square_; + ElementLayernormCompute accum_sum_element_; + + MatrixCoord thread_offset_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorLayerNorm( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord const &problem_size0, ///< Problem size of the output + int thread_idx, ///< Thread index within the threadblock + int warp_idx, ///< Warp index within the threadblock + int lane_idx, ///< Lane index within the warp + MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) + ): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size0), + elementwise_(params.elementwise), + iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), + iterator_D_(params.params_D, params.ptr_D, problem_size0, thread_idx, threadblock_offset) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + // If shift-K feature is enabled, we load shift-k fragment + // at the very beginning of an epilogue + if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { + shift_k_frag_.clear(); + int thread_offset_row_base = iterator_D_.thread_start_row(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { + int step_offset = iter_idx * OutputTileIterator::Shape::kRow; + CUTLASS_PRAGMA_UNROLL + for (int rid = 0; rid < kRowIterations; ++rid) { + int row_step_offset = rid * kDeltaRow; + int row_offset = thread_offset_row_base + step_offset + row_step_offset; + bool is_load = (row_offset < extent_.row()); + shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); + } + + } + + } + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + fragment_C_.clear(); + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + Minus minus; + Mul mul; + Exp exponential; + + LayernormFragment result; + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + + ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); + + // Fragment is cleared for non-reachable columns so no need to check against column guard + accum_sum_element_ = element_sum_accumulator_(result); + + // Square sum is different. Non-reachable columns should've been computed for shift-k + // Otherwise we will incorrectly have some extra k^2 added into square sum. + if (column_guard) { + accum_sum_square_ = (kIsShiftedVariance) ? 
\ + square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ + square_sum_accumulator_(result); + } + else { + accum_sum_square_ = ElementLayernormCompute(0); + } + + accum_sum_element_ *= inv_scalar; + accum_sum_square_ *= inv_scalar; + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { + accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); + accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); + int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); + + ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; + ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; + + arch::global_store( + convert_variance_output(accum_sum_square_), + (void *)curr_ptr_sum_square, + is_write_thread); + + arch::global_store( + convert_mean_output(accum_sum_element_), + (void *)curr_ptr_element_sum, + is_write_thread); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { + using ConvertShiftK = cutlass::NumericConverter; + ConvertShiftK convert_shift_k; + ElementOutput shift_k_val; + + // Computes the address to load shift_k element + ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; + // Conditionally loads from global memory + arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); + // Converts data type to return + ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); + + return converted_shift_k_val; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i]; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i] - shift_k_val; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + sum_ += accum[i]; + } + + return sum_; + } + +}; +
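+// Illustrative note (not part of the CUTLASS API): per row of the GEMM0 output, the partial sums
+// produced by EpilogueVisitorLayerNorm and combined by ApplyFinalReduction correspond to the
+// following reference computation, assuming float accumulation and the same epsilon (1e-6) used above:
+//
+//   float sum = 0.f, sum_sq = 0.f;
+//   for (int j = 0; j < N; ++j) { sum += x[j]; sum_sq += x[j] * x[j]; }
+//   float mean    = sum / N;
+//   float inv_std = 1.f / sqrtf(sum_sq / N - mean * mean + 1e-6f);
+//
+// The tensor referenced by ptr_Variance finally holds inv_std and ptr_Mean holds -mean * inv_std,
+// so the second GEMM can apply y = x * inv_std + (-mean * inv_std) = (x - mean) * inv_std in its
+// mainloop fusion. The shifted-variance path subtracts a per-row shift value (Shifted_K) from the
+// elements before squaring (and from the mean before it is squared), which improves numerical
+// stability without changing the result.
+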
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementInputA0_, + typename LayoutInputA0_, + typename ElementInputB0_, + typename LayoutInputB0_, + typename ElementOutput_, + typename LayoutOutput_, + typename ElementCompute_, + typename EpilogueFunctorOp_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + int Stages0, + int Stages1, + bool IsShiftedVariance_ = false +> +class GemmLayernorm { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + static bool const kInternalTranspose = cutlass::platform::is_same::value; + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // This is a mandatory layout. + using LayoutInputScaleBias = cutlass::layout::RowMajor; + + // These are mandatory data types. + using ElementLayernormCompute = float; + using ElementInputScaleBias = cutlass::half_t; + + // These are mandatory params required by mainloop fusion + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + // These are mandatory layouts and data types + // that are inherited from pre-defined params + + using LayoutSumSqr = LayoutInputScaleBias; + using LayoutSum = LayoutInputScaleBias; + + using ElementMean = ElementInputScaleBias; + using ElementVariance = ElementInputScaleBias; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using LayoutInputA0 = LayoutInputA0_; + using LayoutInputB0 = LayoutInputB0_; + using LayoutInputA1 = LayoutOutput_; + using LayoutInputB1 = LayoutOutput_; + using LayoutOutputC0 = LayoutOutput_; + using LayoutOutputC1 = LayoutOutput_; + + using ElementInputA0 = ElementInputA0_; + using ElementInputB0 = ElementInputB0_; + using ElementOutputC0 = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementInputB1 = ElementInputB0_; + + using ElementInputA1 = ElementOutputC0; + using ElementOutputC1 = ElementOutputC0; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + static int const kStages0 = Stages0; + static int const kStages1 = Stages1; + + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using MapArguments = cutlass::gemm::kernel::detail::MapArguments< + ElementInputA0, + LayoutInputA0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementInputB0, + LayoutInputB0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + LayoutOutputC0, + kInternalTranspose + >; + + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementOutputC0, + typename MapArguments::LayoutC, + ElementCompute, + OperatorClass, + ArchTag, +
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages0, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, + ElementCompute, + ElementVariance, + ElementMean, + ElementLayernormCompute, + EpilogueFunctorOp, + kIsShiftedVariance + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + SwizzleThreadBlock + >; + + using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + ElementVariance, + ElementMean, + ElementLayernormCompute, + ElementOutputC0, + ThreadblockShape, + kIsShiftedVariance + >; + +using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< + ElementInputA1, LayoutInputA1, + ElementInputB1, LayoutInputB1, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutputC1, LayoutOutputC1, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages1 +>; + +public: + + /// Arguments class + struct Arguments { + + typename GemmEpilogueFusion::Arguments gemm0; + typename GemmMainloopFusion::Arguments gemm1; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size0, + cutlass::gemm::GemmCoord problem_size1, + ElementInputA0 * ptr_A, + ElementInputB0 * ptr_B, + ElementOutputC0 * ptr_C, + ElementOutputC0 * ptr_D, + ElementOutputC0 * ptr_E, + ElementOutputC0 * ptr_O, + int64_t ldm_A, + int64_t ldm_B, + int64_t ldm_C, + int64_t ldm_D, + int64_t ldm_E, + int64_t ldm_O, + typename EpilogueFunctorOp::Params linear_scaling, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + TensorVariance ref_Gamma_, + TensorMean ref_Beta_, + ElementOutputC0 *ptr_Shifted_K = nullptr + ): + gemm0( + cutlass::gemm::GemmUniversalMode::kGemm, + {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n(),\ + problem_size0.k()}, + {kInternalTranspose ? ptr_B : ptr_A, \ + kInternalTranspose ? ldm_B : ldm_A}, + {kInternalTranspose ? ptr_A : ptr_B, \ + kInternalTranspose ? ldm_A : ldm_B}, + typename EpilogueVisitor::Arguments( + linear_scaling, + {ptr_C, ldm_C}, + {ptr_D, ldm_D}, + ref_Variance_.data(), + ref_Mean_.data(), + ptr_Shifted_K + ) + ), + reduction( + MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n()), + ref_Variance_, + ref_Mean_, + ptr_Shifted_K + ), + gemm1( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size1, + 1, + linear_scaling, + kInternalTranspose ? ptr_E : ptr_D, + kInternalTranspose ? 
ptr_D : ptr_E, + ref_Variance_.data(), + ref_Mean_.data(), + ref_Gamma_.data(), + ref_Beta_.data(), + ptr_O, + ptr_O, + problem_size1.m() * problem_size1.k(), + problem_size1.n() * problem_size1.k(), + problem_size1.n(), + problem_size1.n(), + problem_size1.k(), + problem_size1.k(), + problem_size1.m() * problem_size1.n(), + problem_size1.m() * problem_size1.n(), + kInternalTranspose ? ldm_E : ldm_D, + kInternalTranspose ? ldm_D : ldm_D, + ref_Variance_.layout().stride(0), + ref_Mean_.layout().stride(0), + ref_Gamma_.layout().stride(0), + ref_Beta_.layout().stride(0), + ldm_O, + ldm_O + ), + extend(problem_size0) + { + + } + }; + + struct Params { + + typename GemmEpilogueFusion::Params gemm0; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm0(args.gemm0), + reduction(args.reduction), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + GemmMainloopFusion gemm_fusion_op; + +public: + + /// Ctor + GemmLayernorm() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + cutlass::Status status; + size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); + cutlass::device_memory::allocation workspace(workspace_size); + status = gemm_fusion_op.can_implement(args.gemm1); + CUTLASS_CHECK(status); + + status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + layernorm kernel + // + + dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); + dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); + + cutlass::Kernel<<>>(params_.gemm0); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the ApplyFinalReductionKernel + // + + // always performs reduction from leading dimension + int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); + int leading_dim_1 = kInternalTranspose ? 
params_.extend.column() : params_.extend.row(); + + int thread_per_block = 128; + int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_block(thread_per_block); + dim3 final_reduction_grid(block_per_row); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the GEMM + mainloop fusion kernel + // + + cutlass::Status status = gemm_fusion_op(); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/38_syr2k_grouped/CMakeLists.txt b/examples/38_syr2k_grouped/CMakeLists.txt new file mode 100644 index 00000000..9153cbd9 --- /dev/null +++ b/examples/38_syr2k_grouped/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 38_syr2k_grouped + syr2k_grouped.cu + ) + diff --git a/examples/38_syr2k_grouped/syr2k_grouped.cu b/examples/38_syr2k_grouped/syr2k_grouped.cu new file mode 100644 index 00000000..245ef6ac --- /dev/null +++ b/examples/38_syr2k_grouped/syr2k_grouped.cu @@ -0,0 +1,1461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief SYR2K Grouped Example. + + This workload computes a batch of SYR2K operations with distinct problem sizes. This example closely + follows 24_gemm_grouped. 
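+
+    For reference, a rank-2k update (SYR2K) computes C = alpha * (A * B^T + B * A^T) + beta * C,
+    where C is an N-by-N symmetric matrix (only the half selected by the fill mode is updated) and
+    A and B are N-by-K. Because two matrix products contribute to each update, the FLOP count is
+    roughly twice that of a single N-by-N-by-K GEMM, which is how GFLOPs are reported below.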
+ + Examples: + + # Runs a grouped SYR2K with 100 random problem sizes + $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 + + # Runs a grouped SYR2K with 100 random problem sizes (with SYR2K-K dimension equal to 1024) + $ ./examples/38_syr2k_grouped/24_gemm_grouped --groups=100 --k=1024 --verbose=true + + # Runs a grouped SYR2K that is equivalent to a batched SYR2K + $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true + + # Execute grouped SYR2K and profile with NSight + $ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true \ + --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include + +#include "cutlass/blas3.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double initialization_time_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double initialization_time_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool profile_initialization; + bool sort_problems; + + std::vector problem_sizes; + + int alignment; + int problem_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + float beta; + std::string benchmark_path; + + std::string output_tag; + std::ofstream output_file; + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; 
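+
+  // Note: kDeviceOnly computes the mapping of threadblocks to problems on the device at run time,
+  // while kHostPrecompute precomputes that schedule on the host and copies it to the device before
+  // launch.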
+ + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(8), + reference_check(true), + profile_initialization(false), + sort_problems(false), + problem_count(5), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 8); + cmd.get_cmd_line_argument("groups", problem_count, 5); + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + + std::string output_path; + cmd.get_cmd_line_argument("tag", output_tag); + cmd.get_cmd_line_argument("output_file", output_path); + + if (!output_path.empty()) { + + std::ios_base::openmode open_mode = std::ios_base::out; + + std::ifstream input_file(output_path.c_str()); + + if (input_file.good()) { + open_mode = std::ios_base::app; + input_file.close(); + } + + output_file.open(output_path.c_str(), open_mode); + + if (output_file.good() && open_mode != std::ios_base::app) { + output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; + } + } + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + error = true; + problem_sizes.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. + // + + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + // SYR2K is defined via only N and K. + if (cmd_line_m != -1) { + std::cerr << "Parameter M is ignored for SYR2K\n"; + error = true; + return; + } + + problem_sizes.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + int n = cmd_line_n; + int k = cmd_line_k; + + if (n < 1) { + n = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + // SYR2K is defined only in terms of N and K. 
Replicate N into + // the SYR2K-N dimension. + cutlass::gemm::GemmCoord problem(n, n, k); + + problem_sizes.push_back(problem); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes.push_back(extent); + } + } + + problem_count = int(problem_sizes.size()); + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "38_syr2k_grouped\n\n" + << " This example profiles the performance of a 'grouped' SYR2K kernel. This example closely follows 24_gemm_grouped.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual SYR2K problems (default: --groups=5)\n" + << " --m= Not used for SYR2K; problem sizes are defined by N and K only\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profiled for the grouped kernel scheduler (default: --scheduler-modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of SYR2K-K dimension.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped SYR2K with 100 random problem sizes\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100\n\n" + + << "# Runs a grouped SYR2K with 100 random problem sizes (with K dimension equal to 1024)\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped SYR2K that is equivalent to a batched SYR2K\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped SYR2K with each different scheduler mode\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped SYR2K with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all --profile-initialization=true\n\n" + + << "# Runs a grouped SYR2K problem given an externally supplied benchmark file. This is a text file in which\n" + << "# Each line contains a unique group index and an MxNxK triple indicating problem size.
NOTE that the\n" + << "# GEMM-M and GEMM-N dimensions must match.\n" + << "#\n" + << "# For example, assume the following are the contents of 'problems.txt'\n" + << "#\n" + << "# 0 256x256x520\n" + << "# 1 264x264x1024\n" + << "# 2 48x48x1024\n" + << "#\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n" + + << "# Execute Grouped SYR2K and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes) { + fmas += problem.product(); + } + + // SYR2K is defined as (A x BT) + (B x AT), so the number of FMAs is twice that in a GEMM + fmas *= 2; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BaseTestbed { +public: + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + BaseTestbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + int problem_count() const { + return options.problem_count; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if 
(bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Allocates device-side data + void allocate() { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + lda_host.resize(problem_count()); + ldb_host.resize(problem_count()); + ldc_host.resize(problem_count()); + ldd_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes.at(i); + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + lda.reset(problem_count()); + ldb.reset(problem_count()); + ldc.reset(problem_count()); + ldd.reset(problem_count()); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(problem_count()); + std::vector ptr_B_host(problem_count()); + std::vector ptr_C_host(problem_count()); + std::vector ptr_D_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count()); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count()); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count()); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count()); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 
2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), block_D.size(), ElementC(), ElementC()); + } + + /// Verifies the result is a SYR2K + bool verify() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + cutlass::HostTensor host_A( + typename LayoutA::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); + cutlass::HostTensor host_B( + typename LayoutB::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); + cutlass::HostTensor host_C( + typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); + cutlass::HostTensor host_D( + typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); + + cutlass::device_memory::copy_to_host(host_A.host_data(), block_A.get() + offset_A.at(i), problem.n() * problem.k()); + cutlass::device_memory::copy_to_host(host_B.host_data(), block_B.get() + offset_B.at(i), problem.n() * problem.k()); + cutlass::device_memory::copy_to_host(host_C.host_data(), block_C.get() + offset_C.at(i), problem.n() * problem.n()); + cutlass::reference::host::BlockFillSequential( + host_D.host_data(), problem.n() * problem.n(), ElementC(), ElementC()); + + MatrixCoord extent_C{problem.n(), problem.n()}; + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementC, ElementAccumulator + >( + problem, + (double)options.alpha, + host_A.host_view(), + Rank2K::kTransformA, + host_B.host_view(), + Rank2K::kTransformB, + (double)options.beta, + host_C.host_view(), + host_D.host_view(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref = host_D.host_view(); + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + return passed; + } + } + + return passed; + } +}; + +template +class TestbedConventional : BaseTestbed { +public: + TestbedConventional( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + /// Verbose printing of problem sizes + void print_problem_sizes() { + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = + ((problem.m() + Rank2K::ThreadblockShape::kM - 1) / Rank2K::ThreadblockShape::kM) * + ((problem.n() + Rank2K::ThreadblockShape::kN - 1) / Rank2K::ThreadblockShape::kN); + + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << 
problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Executes a conventional SYR2K kernel. + Result profile() { + std::cout << "Conventional Rank2K:\n" + << "====================================================" << std::endl; + + Result result; + result.passed = false; + + // Initialize the problem + this->allocate(); + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // + // Create CUDA streams to maximize concurrency of SYR2K kernels + // + int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); + std::vector cuda_streams; + char const *provider = "CUTLASS"; + + // + // Warmup run + // + + if (this->options.cuda_streams) { + for (int i = 0; i < this->options.cuda_streams; ++i) { + cudaStream_t stream; + + result.error = cudaStreamCreate(&stream); + if (result.error != cudaSuccess) { + std::cerr << "Failed to create CUDA stream." << std::endl; + return result; + } + cuda_streams.push_back(stream); + } + } + else { + cuda_streams.push_back(nullptr); + } + + // Use 'D' for the in/out workspace + this->block_D.copy_from_device(this->block_C.get()); + + for (int i = 0; i < this->options.problem_sizes.size(); ++i) { + cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; + int32_t batch_count = 1; + int64_t lda = this->lda_host.at(i); + int64_t ldb = this->ldb_host.at(i); + int64_t ldc = this->ldc_host.at(i); + typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); + typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); + typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); + typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); + + // + // Initialize the CUTLASS SYR2K operator + // + + // Configure the SYR2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + typename Rank2K::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + epilogue_op, + (void const *)ptrA, + (void const *)ptrB, + (void const *)ptrC, + (void *)ptrD, + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(lda), + int64_t(ldb), + int64_t(ldc), + int64_t(ldc) + }; + + Rank2K rank2k_op; + + cutlass::Status status = rank2k_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + + status = rank2k_op(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + } + + // + // Wait for completion + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // + // Wait for completion + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // Record an event at the start of a series of SYR2K operations + result.error = 
cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + int last_stream_idx = 0; + + for (int iter = 0; iter < this->options.iterations; ++iter) { + for (int i = 0; i < this->options.problem_sizes.size(); ++i) { + cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; + int32_t batch_count = 1; + int64_t lda = this->lda_host.at(i); + int64_t ldb = this->ldb_host.at(i); + int64_t ldc = this->ldc_host.at(i); + typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); + typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); + typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); + typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); + + last_stream_idx = (i % effective_streams); + + // + // Initialize the CUTLASS SYR2K operator + // + + // Configure the SYR2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + typename Rank2K::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + epilogue_op, + (void const *)ptrA, + (void const *)ptrB, + (void const *)ptrC, + (void *)ptrD, + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(lda), + int64_t(ldb), + int64_t(ldc), + int64_t(ldc) + }; + + Rank2K rank2k_op; + + cutlass::Status status = rank2k_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + + status = rank2k_op(cuda_streams[last_stream_idx]); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + } + } + + // + // Stop profiling loop + // + + // Record an event when the SYR2K operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Wait for work to be completed + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
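+    // The elapsed time between the two events covers options.iterations passes over all of the problems, +    // so dividing by options.iterations gives the average runtime of one pass; options.gflops() then +    // converts that per-pass runtime (in seconds) to GFLOP/s.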
+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + for (auto stream : cuda_streams) { + if (stream) { + (void)cudaStreamDestroy(stream); + } + } + + std::cout << " " << this->options.problem_sizes.size() << " conventional Rank2Ks launched" << std::endl; + std::cout << std::endl; + std::cout << " " << "Conventional Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Conventional GFLOPS: " << result.gflops << std::endl; + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << "," << provider << ",conventional," + << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + result.passed = true; + return result; + } +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ) : BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine Rank2K with different GroupScheduleMode_ + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + typename Rank2K_::ElementA, typename Rank2K_::LayoutA, Rank2K_::kTransformA, Rank2K_::kAlignmentA, + typename Rank2K_::ElementB, typename Rank2K_::LayoutB, Rank2K_::kTransformB, Rank2K_::kAlignmentB, + typename Rank2K_::ElementC, typename Rank2K_::LayoutC, Rank2K_::kFillModeC, + typename Rank2K_::ElementAccumulator, + typename Rank2K_::OperatorClass, + typename Rank2K_::ArchTag, + typename Rank2K_::ThreadblockShape, + typename Rank2K_::WarpShape, + typename Rank2K_::InstructionShape, + typename Rank2K_::EpilogueOutputOp, + typename Rank2K_::ThreadblockSwizzle, + Rank2K_::kStages, + typename Rank2K_::Operator::ArchMmaOperator::Operator, + Rank2K_::kBlasMode, + GroupScheduleMode_>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Rank2K::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Rank2K::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime. 
+ Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + std::cout << std::endl; + std::cout << "Grouped Rank2K (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Rank2K::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped SYR2K kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the Rank2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = rank2k.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // Run the grouped Rank2K object + result.status = rank2k.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped Rank2K object + // + result.status = rank2k.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of SYR2K operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + rank2k(); + } + + // + // Stop profiling loop + // + + // Record an event when the Rank2K operations have been launched. 
+ result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + rank2k.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + rank2k.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Rank2K::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's Grouped Rank2K example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + + // + // Define the Grouped and Conventional Rank2K types + // + + using ElementA = double; + using ElementB = double; + using ElementOutput = double; + using ElementAccumulator = double; + const cutlass::FillMode kFillModeC = cutlass::FillMode::kLower; + const int kAlignmentA = 1; + const int kAlignmentB = 1; + const cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; + const cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, + ElementAccumulator, ElementAccumulator>; + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + const int kStages = 4; + const bool kSplitKSerial = false; + using Operator = cutlass::arch::OpMultiplyAdd; + const cutlass::BlasMode kBlasMode = cutlass::BlasMode::kSymmetric; + + // Define a grouped Rank2K kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, kTransformA, kAlignmentA, + ElementB, LayoutB, kTransformB, kAlignmentB, + ElementOutput, LayoutC, kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + Operator, + kBlasMode>::Rank2Kkernel; + + using Rank2KGrouped = cutlass::gemm::device::Rank2KGrouped; + + // Rank2k operator + using Rank2KConventional = cutlass::gemm::device::Rank2K< + ElementA, LayoutA, + ElementB, LayoutB, + ElementOutput, LayoutC, kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial, + Operator, + kTransformA, + kTransformB, + kBlasMode + >; + + // + // Profile it + // + + TestbedConventional testbed(options); + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS conventional Rank2K has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/39_gemm_permute/CMakeLists.txt b/examples/39_gemm_permute/CMakeLists.txt new file mode 100644 index 00000000..d503fcac --- /dev/null +++ 
b/examples/39_gemm_permute/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 39_gemm_permute + gemm_permute.cu + ) + diff --git a/examples/39_gemm_permute/gemm_permute.cu b/examples/39_gemm_permute/gemm_permute.cu new file mode 100644 index 00000000..b4649b83 --- /dev/null +++ b/examples/39_gemm_permute/gemm_permute.cu @@ -0,0 +1,1126 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief GEMM Permute Example. + + This example computes batched GEMM operations with output results permuted as reshaped tensors. + + We provide a layout plugin as a flexible tool for users to add any customized output-tensor permute operation, + or any other generalized global memory write-out address computation. To add a customized layout, add a new class + to include/cutlass/layout/permute.h + + In this example, we use the Tensor4DPermuteBMM0213 layout to perform a batched GEMM with permute([0, 2, 1, 3]) applied + to the whole BMM output tensor, and the Tensor5DPermute20314 layout to perform a normal GEMM with permute([2, 0, 3, 1, 4]) + applied to the output matrix. The address computations are performed in compute(col_init, row_init, stride_init, + BMM_batch_idx), which produces {col_permute, row_permute, stride_permute} as the new addresses after the permute op + (see include/cutlass/layout/permute.h). + + Tips: + + 1) Make sure to set batch_stride_D to zero for the BMM permute. Also, the BMM should run in mode + cutlass::gemm::GemmUniversalMode::kBatched instead of kArray. + + 2) When the permute op touches the last dimension (for example permute([0, 2, 3, 1])), AlignmentC must + be set to 1. If the last dimension is untouched, AlignmentC can be larger, such as 8 in our example. + Permute ops that leave the last dimension untouched are therefore recommended for the best performance.
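+ +    For intuition, here is an illustrative sketch of the index mapping (not the layout plugin's actual code): with the +    row-major [B, M, N] BMM output used in this example and the compile-time constant D1 defined below, viewing the +    tensor as [B/D1, D1, M, N] and applying permute([0, 2, 1, 3]) means that the logical output element (b, m, n) is +    written at the linear offset + +      offset = (((b / D1) * M + m) * D1 + (b % D1)) * N + n + +    i.e. the offset it would occupy in a packed [B/D1, M, D1, N] tensor.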
+ + Examples: + + # Runs a batched GEMM with 96 batches + $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 + + # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) + $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true + + # Execute batched GEMM and profile with NSight + $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include <algorithm> +#include <iostream> +#include <memory> +#include <random> +#include <sstream> +#include <vector> + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/layout/permute.h" + +/// Tensor4DPermuteBMM0213 ---> +/// Permute layout function for 4-D permuted tensors in a BMM: the BMM output tensor (with dimensions [B, M, N]) is reshaped +/// as [B/D1, D1, M, N], and permute([0, 2, 1, 3]) is then applied to the corresponding whole BMM output tensor. +const int D1 = 12; + +/// Tensor5DPermute20314 ---> +/// Permute layout function for 5-D permuted tensors: the output matrix (with dimensions [M, N]) is reshaped +/// as [M/T1, T1, T2, T3, N/T2/T3], and permute([2, 0, 3, 1, 4]) is then applied to the corresponding output tensor.
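+/// As a concrete example with this example's default problem size (m = 128, n = 192) and the constants below +/// (T1 = 16, T2 = 3, T3 = 8), the [128, 192] output matrix is viewed as [8, 16, 3, 8, 8] and permuted to [3, 8, 8, 16, 8].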
+const int T1 = 16; +const int T2 = 3; +const int T3 = 8; + +// Alignment C +const int AlignmentC = 8; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + + cutlass::gemm::GemmCoord problem_each; + + int batch_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + float beta; + + // + // Methods + // + + Options(): + help(false), + error(false), + reference_check(true), + batch_count(-1), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta() + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + + int m, n, k; + + cmd.get_cmd_line_argument("m", m, 128); + cmd.get_cmd_line_argument("n", n, 192); + cmd.get_cmd_line_argument("k", k, 128); + cmd.get_cmd_line_argument("batch-count", batch_count, 768); + + cutlass::gemm::GemmCoord problem(m, n, k); + problem_each = problem; + + if (batch_count % D1 != 0){ + std::cerr << "\nProblem count error (problem-count = " << batch_count << "). " + << "problem-count needs to be divided with no remain by " << D1 << " (D1)." + << " (Required by the Batched GEMM permute Tensor4DPermuteBMM0213)\n\n"; + error = true; + } + + if (m % (AlignmentC * T1) != 0){ + std::cerr << "\nProblem m size error (m = " << m << "). " + << "m needs to be divided with no remain by " << (AlignmentC * T1) << " (AlignmentC * T1)." + << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; + error = true; + } + + if (n % (AlignmentC * (T2 * T3)) != 0){ + std::cerr << "\nProblem n size error (n = " << n << "). " + << "n needs to be divided with no remain by " << (AlignmentC * (T2 * T3)) << " (AlignmentC * T2 * T3)." + << " (Required by the normal GEMM permute Tensor5DPermute20314)\n\n"; + error = true; + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "39_gemm_permute\n\n" + << " 1) This example firstly profiles the performance of a batched GEMM kernel with BMM whole output" + << " (including output matrices for each batch) as permuted 4D Tensor." + << " The BMM tensor output in shape of [B, M, N] is reshaped as [B/D1, D1, M, N] and then permuted with" + << " permute([0, 2, 1, 3]) to be in shape of [B/D1, M, D1, N].\n\n" + << " 2) This example also profiles the performance of a normal GEMM kernel with output as permuted 5D Tensor." 
+ << " The GEMM matrix output in shape of [M, N] is reshaped as [M/T1, T1, T2, T3, N/T2/T3] and then permuted" + << " with permute([2, 0, 3, 1, 4]) to be in shape of [T2, M/T1, T3, T1, N//T2/T3].\n\n" + << " Note: D1, T1, T2, T3 are compile-time constants defined in gemm_permute.cu\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" + << " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" + << " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" + << " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=128)\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a batched GEMM with 96 batches\n" + << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96\n\n" + + << "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" + << "$ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true\n\n" + + << "# Execute batched GEMM and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + fmas += problem_each.product() * batch_count; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Testbed { +public: + + // + // Type definitions + // + + using ElementA = typename GemmBatched::ElementA; + using ElementB = typename GemmBatched::ElementB; + using ElementC = typename GemmBatched::ElementC; + using ElementAccumulator = typename GemmBatched::ElementAccumulator; + + using EpilogueOutputOp = typename GemmBatched::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename GemmBatched::LayoutA; + using LayoutB = typename GemmBatched::LayoutB; + using LayoutC = typename GemmBatched::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + +private: + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + +public: + + // + // Methods + // + + Testbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3090 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Verbose BMM info + void 
print_BMM_info_() { + + // Print batched GEMM + std::cout << "Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor:\n"; + + auto problem = options.problem_each; + std::cout + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << ", batch count: " << options.batch_count << "\n"; + + std::cout << "output tensor shape: [" << options.batch_count << ", " << problem.m() << ", " + << problem.n() <<"]\n"; + std::cout << "reshaped as: [" << options.batch_count / D1 << ", " << D1 << ", " + << problem.m() << ", " << problem.n() <<"]\n"; + std::cout << "finally permuted as: [" << options.batch_count / D1 << ", " << problem.m() << ", " + << D1 << ", " << problem.n() <<"]\n"; + + std::cout << "----------------------------------------------------\n"; + + } + + /// Verbose normal GEMM info + void print_GEMM_info_() { + + // Print batched GEMM + std::cout << "Normal GEMM with permute([2, 0, 3, 1, 4]):\n"; + + auto problem = options.problem_each; + std::cout + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() << "\n"; + + std::cout << "output tensor shape: [" << problem.m() << ", " << problem.n() <<"]" << std::endl; + std::cout << "reshaped as: [" << problem.m() / T1 << ", " << T1 << ", " + << T2 << ", " << T3 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; + std::cout << "finally permuted as: [" << T2 << ", " << problem.m() / T1 << ", " + << T3 << ", " << T1 << ", " << problem.n() / T2 / T3 <<"]" << std::endl; + + std::cout << "----------------------------------------------------\n"; + + } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_(int batch_count) { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; + int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; + int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; + int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; + + // + // Assign space + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + 
block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); + initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); + initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), total_elements_D, ElementC(), ElementC()); + } + + /// Verifies the BMM GEMM result + bool verify_BMM_() { + + bool passed = true; + + cutlass::gemm::GemmCoord problem = options.problem_each; + + LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); + LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); + LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); + LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); + cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); + + cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C) * options.batch_count); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + options.alpha, + view_A, + GemmBatched::kTransformA, + view_B, + GemmBatched::kTransformB, + options.beta, + view_C, + view_Ref_device, + ElementAccumulator(0), + options.batch_count, + options.problem_each.m() * options.problem_each.k(), + options.problem_each.n() * options.problem_each.k(), + options.problem_each.m() * options.problem_each.n(), + options.problem_each.m() * options.problem_each.n() + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C) * options.batch_count); + std::vector matrix_Ref(layout_D.capacity(extent_C) * options.batch_count); + + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); + + // Print out the results and reference in 4D Tensor + // [options.batch_count, options.problem_each.m() * options.problem_each.n()] -> [D0, D1, D2, D3]. + // After permute Op, -> [D0, D2, D1, D3]. 
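+    // Here D0 = batch_count / D1, D2 = M, and D3 = N, with D1 the compile-time constant defined above. +    // The host loop below applies the same permute([0, 2, 1, 3]) to the reference result so that it can be +    // compared element-wise against the kernel's permuted output.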
+ int D0 = options.batch_count / D1; + int D2 = options.problem_each.m(); + int D3 = options.problem_each.n(); + + cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently + cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D2, D1, D3})), cutlass::Tensor4DCoord({D0, D2, D1, D3})); + + cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), + cutlass::layout::TensorNHWC().packed(cutlass::Tensor4DCoord({D0, D1, D2, D3})), cutlass::Tensor4DCoord({D0, D1, D2, D3})); + + // Tensor Permute Op on reference tensor + cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor4DCoord({D0, D2, D1, D3})); + for (int n = 0; n < D0; ++n) { + for (int h = 0; h < D1; ++h) { + for (int w = 0; w < D2; ++w) { + for (int c = 0; c < D3; ++c) { + view_Ref_Permute_Tensor.at({n, w, h, c}) = view_Ref_Tensor.at({n, h, w, c}); + } + } + } + } + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); + + if (!passed) { + std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; + return passed; + } + + std::cout << "Passed verification" << std::endl; + return passed; + } + + bool verify_GEMM_normal_() { + + bool passed = true; + + cutlass::gemm::GemmCoord problem = options.problem_each; + + LayoutA layout_A(LayoutA::packed({problem.m(), problem.k()}).stride(0)); + LayoutB layout_B(LayoutB::packed({problem.k(), problem.n()}).stride(0)); + LayoutC layout_C(LayoutC::packed({problem.m(), problem.n()}).stride(0)); + LayoutC layout_D(LayoutC::packed({problem.m(), problem.n()}).stride(0)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); + cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); + + cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + options.alpha, + view_A, + GemmBatched::kTransformA, + view_B, + GemmBatched::kTransformB, + options.beta, + view_C, + view_Ref_device, + ElementAccumulator(0) + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get(), matrix_D.size()); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); + + // Print out the results and reference in 5D Tensor + // [options.problem_each.m(), options.problem_each.n()] -> [T0, T1, T2, T3, T4]. + // options.problem_each.m() == T0 * T1 + // options.problem_each.n() == T2 * T3 * T4 + // After permute Op, -> [T2, T0, T3, T1, T4]. 
+ int T0 = options.problem_each.m() / T1; + int T4 = options.problem_each.n() / T2 / T3; + + cutlass::TensorView view_D_Tensor(matrix_D.data(), // if LayoutC = cutlass::layout::ColumnMajor, view_D_Tensor should be constructed differently + cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})), cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); + cutlass::TensorView view_Ref_Tensor(matrix_Ref.data(), + cutlass::layout::TensorNDHWC().packed(cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})), cutlass::Tensor5DCoord({T0, T1, T2, T3, T4})); + + // Tensor Permute Op on reference tensor + cutlass::HostTensor view_Ref_Permute_Tensor(cutlass::Tensor5DCoord({T2, T0, T3, T1, T4})); + for (int n = 0; n < T0; ++n) { + for (int d = 0; d < T1; ++d) { + for (int h = 0; h < T2; ++h) { + for (int w = 0; w < T3; ++w) { + for (int c = 0; c < T4; ++c) { + view_Ref_Permute_Tensor.at({h, n, w, d, c}) = view_Ref_Tensor.at({n, d, h, w, c}); // permute([2,0,3,1,4]) + } + } + } + } + } + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_Ref_Permute_Tensor.host_view(), view_D_Tensor); + + if (!passed) { + std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; + return passed; + } + + std::cout << "Passed verification" << std::endl; + return passed; +} + +public: + /// Executes a conventional batched GEMM kernel. + Result profile_batched_kBatched() { + + std::cout << "\n====================================================" << std::endl; + std::cout << "Batched GEMM (CUTLASS):\n" + << "====================================================" << std::endl; + + if (options.verbose) { + print_BMM_info_(); + } + + Result result; + + result.passed = false; + + // Initialize the problem + initialize_(options.batch_count); + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + + // Please make sure all problem_sizes are the same for kBatched mode + auto problem = options.problem_each; + + // For regular BMM + int64_t batch_stride_C = problem.m() * problem.n(); + // For BMM permute output ---> make sure to set batch_stride_D to zero for BMM permute op + int64_t batch_stride_D = 0; + + // Configure GEMM arguments + typename GemmBatched::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kBatched, + options.problem_each, + options.batch_count, + epilogue_op, + (void*)block_A.get(), + (void*)block_B.get(), + (void*)block_C.get(), + (void*)block_D.get(), + problem.m() * problem.k(), + problem.n() * problem.k(), + batch_stride_C, + batch_stride_D, + problem.k(), + problem.n(), + problem.n(), + problem.n() + }; + + // Initialize the GEMM object + GemmBatched gemm; + + result.status = gemm.initialize(arguments); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; + return result; + } + + // Run the batched GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Batched GEMM kernel." 
<< std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_BMM_(); + } + + // + // Warm-up run of the batched GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << " " << 1 << " batched GEMMs launched\n"; + + std::cout << std::endl; + std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms\n"; + std::cout << " " << "Batched GFLOPs: " << result.gflops << "\n"; + + return result; + } + + Result profile_GEMM_permute() { + + std::cout << "\n====================================================" << std::endl; + std::cout << "Normal GEMM (CUTLASS):\n" + << "====================================================" << std::endl; + + if (options.verbose) { + print_GEMM_info_(); + } + + Result result; + + result.passed = false; + + // Initialize the problem + initialize_(1); + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + + // Please make sure all problem_sizes are the same for kBatched mode + auto problem = options.problem_each; + + // Configure GEMM arguments + typename GemmPermute::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + options.problem_each, + 1, + epilogue_op, + (void*)block_A.get(), + (void*)block_B.get(), + (void*)block_C.get(), + (void*)block_D.get(), + 0, + 0, + 0, + 0, + problem.k(), + problem.n(), + problem.n(), + problem.n() + }; + + // Initialize the GEMM object + GemmPermute gemm_normal; + + result.status = gemm_normal.initialize(arguments); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Batched GEMM kernel." << std::endl; + return result; + } + + // Run the normal GEMM object + result.status = gemm_normal.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_GEMM_normal_(); + } + + // + // Warm-up run of the normal GEMM object + // + result.status = gemm_normal.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Batched GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + gemm_normal(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. 
+ result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << " " << "Normal Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Normal GFLOPs: " << result.gflops << "\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + + // + // Define the GEMM types + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + // + // Define a conventional batched GEMM type + // + + // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 + using GemmBatched = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, LayoutA, + cutlass::half_t, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + AlignmentC, //128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + 8, /*alignmentA*/ + 8, /*alignmengB*/ + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + cutlass::layout::Tensor4DPermuteBMM0213 /*PermuteDLayout*/ + >; + + // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 + using GemmPermute = cutlass::gemm::device::GemmUniversal< + cutlass::half_t, LayoutA, + cutlass::half_t, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + AlignmentC, //128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, + 8, /*alignmentA*/ + 8, /*alignmengB*/ + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + cutlass::layout::Tensor5DPermute20314 /*PermuteDLayout*/ + >; + + // + // Profile it + // + + Testbed testbed(options); + + Result result; + result = testbed.profile_batched_kBatched(); + if (!result.passed) { + std::cout << "Profiling batched GEMM has failed.\n"; + std::cout << "\nFailed\n"; + } else { + std::cout << "\nPassed CUTLASS batched GEMM\n"; + } + + result = testbed.profile_GEMM_permute(); + if (!result.passed) { + std::cout << "Profiling normal GEMM has failed.\n"; + std::cout << "\nFailed\n"; + } else { + std::cout << "\nPassed CUTLASS normal GEMM\n"; + } + + std::cout << "\n====================================================" << std::endl; + std::cout << "Finished\n"; + std::cout << "====================================================" << std::endl; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md new file mode 100644 index 00000000..9f9b8cac --- /dev/null +++ b/examples/40_cutlass_py/README.md @@ -0,0 +1,162 @@ +# CUTLASS Python Interface Example + +## Using Docker +You can run the PyCUTLASS on NGC pytorch container. 
+```shell +docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.08-py3 +``` +PyCUTLASS additionally requires the Boost C++ library, which can be installed with +```bash +apt-get update +apt-get -y install libboost-all-dev +``` + + +## Install the Python Interface +The source code for the Python interface is located at `tools/library/scripts/pycutlass`. It requires two environment variables: +* `CUTLASS_PATH`: the root directory of CUTLASS +* `CUDA_INSTALL_PATH`: the directory where the CUDA Toolkit is installed + +After setting these two environment variables, PyCUTLASS can be installed with +```shell +cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh +``` +*** + +## Troubleshooting + +### Issue 1: permission denied +Building PyCUTLASS requires installing dependencies into your Python environment, so conda can be an option if you do not have the necessary permissions. + +### Issue 2: rmm: module not found +PyCUTLASS manages device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pulls [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from GitHub and builds it from source. RMM is located at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm` and requires `cmake > 3.20.1`. If the build fails, it can be fixed manually with the following steps: +```shell +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm + +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python +python setup.py build_ext --inplace +python setup.py install +``` +To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues. + +*** +For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message. +## GEMM Examples +The GEMM examples use numpy to create input tensors and verify the results.
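+ +A minimal sketch of the kind of numpy reference check these examples perform (illustrative only; the actual logic lives in `gemm.py`, and the tensor names here are assumptions): +```python +import numpy as np + +# Problem size and epilogue scalars matching the commands below (-p 512 256 128, -alpha 1.0, -beta 0.5). +M, N, K = 512, 256, 128 +alpha, beta = 1.0, 0.5 + +# Host-side input tensors (float64, as in the F64 examples). +A = np.random.randn(M, K) +B = np.random.randn(K, N) +C = np.random.randn(M, N) + +# numpy reference for the GEMM epilogue D = alpha * A @ B + beta * C. +D_ref = alpha * (A @ B) + beta * C + +# D would be the tensor produced by the CUTLASS kernel; the scripts compare it against D_ref, e.g.: +# np.testing.assert_allclose(D, D_ref) +```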
+### GEMM F64 Example +Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16 +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` + +### GEMM F32 Example +Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 +``` +Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 +``` + +### GEMM F16 Example +Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` +Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3 +``` + +### GEMM BF16 Example +Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel 
+```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 +``` + +### GEMM Int8 Example +Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128 +```python +python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 +``` +*** +## GEMM Grouped Examples +The GEMM Grouped examples use numpy to create input tensors and verify the results. + +Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device +``` +Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle2 -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host +``` +Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule +```python +python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle8 -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +*** +## Conv2d Example +The Conv2d examples use pytorch to create input tensors and verify the results. Pytorch can be installed following the [official website](https://pytorch.org/#:~:text=Aid%20to%20Ukraine.-,INSTALL%20PYTORCH,-Select%20your%20preferences). 
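+As a minimal, hypothetical sketch of that pattern (random inputs created with PyTorch, then checked against a PyTorch reference), independent of how `conv2d.py` implements it:
+```python
+# Sketch only: assumes a CUDA-enabled PyTorch build. `y_cutlass` stands in for
+# a result produced by one of the CUTLASS kernels launched by the commands below.
+import torch
+import torch.nn.functional as F
+
+N, H, W, C = 1, 13, 17, 8   # activation extents (cf. the -nhwc flag used below)
+K, R, S = 24, 3, 3          # filter extents (cf. the -krsc flag used below)
+
+# PyTorch's functional API takes NCHW activations and KCRS filters; the CUTLASS
+# example kernels themselves operate on NHWC / KRSC data.
+x = torch.randn(N, C, H, W, device="cuda", dtype=torch.float32)
+w = torch.randn(K, C, R, S, device="cuda", dtype=torch.float32)
+
+# Reference fprop result computed by PyTorch.
+y_ref = F.conv2d(x, w, stride=(2, 2), padding=(0, 0))
+
+# A CUTLASS fprop result for the same problem size would then be compared elementwise:
+# assert torch.equal(y_cutlass, y_ref)
+```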
+### Conv2d F32 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +### Conv2d F32 Wgrad +Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +### Conv2d F32 Dgrad +Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` + +### Conv2d F16 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw 
IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py new file mode 100644 index 00000000..687cfdc4 --- /dev/null +++ b/examples/40_cutlass_py/conv2d.py @@ -0,0 +1,277 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +import pycutlass +from pycutlass import * +from pycutlass.conv2d_operation import * +from pycutlass.utils import reference_model + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], + help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorC32RSK32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", 
"TensorNC32HW32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], + help='Data type of computation in the epilogue') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", + "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", + "StridedDgradHorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# conv related +parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], + help="The type of convolution: forward propagation (fprop), \ + gradient of activation (dgrad), gradient of weight (wgrad)") +parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], + ) +parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, + choices=["analytic", "optimized", "fixed_channels", "few_channels"], + help="This option describes iterator algorithm") + +# arguments +parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], + help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ + Parallel is used for parallel splitK.") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. 
(default 1)") +parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") +parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") +parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") +parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") +parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst, args.compute_capability, args.compute_capability +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +iterator_algorithm = getattr(cutlass.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass, args.swizzling_functor) +stride_support = getattr(StrideSupport, args.stride_support) +conv_kind = getattr(cutlass.conv.Operator, args.conv_kind) + +operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, stride_support=stride_support, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation,] + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +problem_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass.MatrixCoord(args.stride[0], args.stride[1]), + cutlass.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass.conv.Mode.cross_correlation, + args.split_k_slices, 1 +) + + +# User-provide inputs +tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size( + conv_kind, problem_size +) +tensor_B_size = 
cutlass.conv.implicit_gemm_tensor_b_size( + conv_kind, problem_size +) +tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size +) + +if args.element_a != "int8": + tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) + +if args.element_b != "int8": + tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) + +if args.element_c != "int8": + tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) + +tensor_D = torch.ones_like(tensor_C) + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = LinearCombinationFunctorArguments(args.alpha, args.beta), + split_k_mode=getattr(cutlass.conv.SplitKMode, args.split_k_mode), + split_k_slices=problem_size.split_k_slices +) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) + reduction_arguments = ReductionArguments( + reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = LinearCombinationFunctorArguments(args.alpha, args.beta) + ) + +operation.run(arguments) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +reference_model = Conv2dReferenceModule(A, B, C, conv_kind) + +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta) + +assert torch.equal(tensor_D, tensor_D_ref) + +print("Passed.") diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py new file mode 100644 index 00000000..8341d10d --- /dev/null +++ b/examples/40_cutlass_py/gemm.py @@ -0,0 +1,266 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +import numpy as np +import pycutlass +from pycutlass import * +import cutlass +from bfloat16 import bfloat16 + +import argparse + + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM kernels from python: 'D = alpha * A * B + beta * C'") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], + help="This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + 
"RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") + +# Argument +parser.add_argument("-p", "--problem_size", + default=[128, 128, 128], nargs=3, type=int, + help="GEMM problem size M, N, K") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, + help="Scaling factor of A * B") +parser.add_argument("-beta", "--beta", default=0.0, type=float, + help="Scaling factor of C") +parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, + choices=["Gemm", "GemmSplitKParallel"], + help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ + GemmSplitKParallel is used for parallel splitK") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. 
(default 1)") + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +# parser.add_argument('-h', '--help', action="store_true", +# help="print help information") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst, args.compute_capability, args.compute_capability +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +swizzling_functor = getattr(cutlass, args.swizzling_functor) + +operation = GemmOperationUniversal( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +# User-provide inputs + +problem_size = cutlass.gemm.GemmCoord( + args.problem_size[0], args.problem_size[1], args.problem_size[2]) + +if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) +else: + tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + +if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) +else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + +if args.element_c != "int8": + if args.element_c == "bfloat16": + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.n(),))).astype(getattr(np, args.element_c)) +else: + tensor_C = 
np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.n(),)).astype(getattr(np, args.element_c)) + +tensor_D = np.ones_like(tensor_C) + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=LinearCombinationFunctorArguments(args.alpha, args.beta), + gemm_mode=getattr(cutlass.gemm.Mode, args.gemm_mode), + split_k_slices=args.split_k_slices +) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_arguments = ReductionArguments( + operation=reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], + partitions=args.split_k_slices, workspace=arguments.ptr_D, + destination=tensor_D, source=tensor_C, + output_op=LinearCombinationFunctorArguments(args.alpha, args.beta) + ) + +operation.run(arguments) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +# run the host reference module +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta) + +assert np.array_equal(tensor_D, tensor_D_ref) + +print("Passed.") diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py new file mode 100644 index 00000000..e26ecc97 --- /dev/null +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -0,0 +1,248 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ +import pycutlass +from pycutlass import * +import csv + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM Grouped kernels from python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="simt", type=str, + choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, 
choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# precompute mode +parser.add_argument("-pm", "--precompute_mode", + default="Device", type=str, choices=["Host", "Device"], + help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") +# arguments +parser.add_argument("-p", "--problem_size_dir", type=str, + help="path to the csv file contains the problem sizes") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +element_a = getattr(cutlass, args.element_a) +element_b = getattr(cutlass, args.element_b) +element_c = getattr(cutlass, args.element_c) +element_acc = getattr(cutlass, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst, args.compute_capability, args.compute_capability +) + +layout_a = getattr(cutlass, args.layout_a) +layout_b = getattr(cutlass, args.layout_b) +layout_c = getattr(cutlass, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass, args.element_epilogue) +epilogue_functor = getattr(EpilogueFunctor, args.epilogue_functor) +swizzling_functor = getattr(cutlass, args.swizzling_functor) +precompute_mode = getattr(SchedulerMode, args.precompute_mode) + +operation = GemmOperationGrouped( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + precompute_mode=precompute_mode +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +pycutlass.compiler.add_module([operation, ]) + +reference_module = ReferenceModule(A, B, C) + +# get problems +problem_sizes = [] +with open(args.problem_size_dir) as csv_file: + reader = csv.reader(csv_file) + for row in reader: + problem_sizes.append( + cutlass.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + ) + +problem_count = len(problem_sizes) + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +problem_sizes_coord = [] +tensor_D_refs = [] + +for problem_size in problem_sizes: + if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) + else: + tensor_A = 
np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + + if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) + else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + + if args.element_c != "int8": + if args.element_c == "bfloat16": + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.n(),))).astype(getattr(np, args.element_c)) + else: + tensor_C = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.n(),)).astype(getattr(np, args.element_c)) + tensor_D = np.zeros_like(tensor_C) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + tensor_D_refs.append(reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta)) + problem_sizes_coord.append(problem_size) + +arguments = GemmGroupedArguments( + operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=LinearCombinationFunctorArguments(args.alpha, args.beta) +) + +operation.run(arguments) + +arguments.sync() + +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + assert np.array_equal(tensor_d, tensor_d_ref) + +print("Passed.") diff --git a/examples/40_cutlass_py/grouped_gemm_problem_size.csv b/examples/40_cutlass_py/grouped_gemm_problem_size.csv new file mode 100644 index 00000000..d1d0dd00 --- /dev/null +++ b/examples/40_cutlass_py/grouped_gemm_problem_size.csv @@ -0,0 +1,3 @@ +128,128,128 +128,128,256 +512,128,384 diff --git a/examples/40_cutlass_py/test-cutlass-py.py b/examples/40_cutlass_py/test-cutlass-py.py deleted file mode 100644 index e1ee636b..00000000 --- a/examples/40_cutlass_py/test-cutlass-py.py +++ /dev/null @@ -1,169 +0,0 @@ - -# System modules -import numpy as np -import os.path -import sys -import ctypes - -# CUDA Python modules -from cuda import cuda -from cuda import nvrtc - -# CUTLASS modules -import library -import manifest as cutlass_manifest -import generator -import rt - - -# -# Construct an SGEMM -# - -manifest = cutlass_manifest.Manifest() - -generator.GenerateSM50_Simt(manifest, "11.5.0") - -# -# Construct a GEMM operation -# - -operation = manifest.operations_by_name['cutlass_simt_sgemm_128x128_8x2_nt_align1'] - -# -# Construct a runtime GEMM operation -# -gemm = rt.Gemm(operation) - -# -# Initialize context -# -err, = cuda.cuInit(0) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, device = cuda.cuDeviceGet(0) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, context = cuda.cuCtxCreate(0, device) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -# -# Construct a module -# - -architectures = [80,] -include_paths = [ - '../../include', - '../../tools/util/include', -] - -compilation_options = rt.CompilationOptions(architectures, include_paths) - -module = rt.Module('module.cu', [gemm], compilation_options) - -# -# Setup a 
workspace -# - -M, N, K = (128, 128, 128) - -tensor_A = np.ndarray(M * K, dtype=np.float32) -tensor_B = np.ndarray(N * K, dtype=np.float32) -tensor_C = np.ndarray(M * N, dtype=np.float32) -tensor_D = np.ndarray(M * N, dtype=np.float32) - -err, tensor_A_d = cuda.cuMemAlloc(tensor_A.size * tensor_A.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_B_d = cuda.cuMemAlloc(tensor_B.size * tensor_B.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_C_d = cuda.cuMemAlloc(tensor_C.size * tensor_C.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_D_d = cuda.cuMemAlloc(tensor_D.size * tensor_D.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, stream = cuda.cuStreamCreate(0) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -tensors = [ - (tensor_A_d, tensor_A), - (tensor_B_d, tensor_B), - (tensor_C_d, tensor_C), - (tensor_D_d, tensor_D) -] - -for tensor_device, tensor_host in tensors: - bytes = tensor_host.size * tensor_host.itemsize - print("Tensor has dimensions: %s (%d bytes)" % (str(tensor_host.size), tensor_host.itemsize)) - err, = cuda.cuMemcpyHtoDAsync(tensor_device, tensor_host, bytes, stream) - print("updating tensor in device memory ", hex(int(tensor_device))) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) - -# -# Initialize a host buffer -# - -arguments = rt.GemmArguments() - -arguments.problem_size = rt.GemmCoord(M, N, K) - -arguments.A = rt.TensorRef(tensor_A_d, M) -arguments.B = rt.TensorRef(tensor_B_d, N) -arguments.C = rt.TensorRef(tensor_C_d, M) -arguments.D = rt.TensorRef(tensor_D_d, M) - -host_workspace = bytearray(gemm.get_host_workspace_size(arguments)) -device_workspace = None - -launch_config = gemm.plan(arguments) - -byte_count = gemm.initialize(host_workspace, device_workspace, launch_config, arguments) - -# -# Launch the kernel -# - -err = gemm.run(host_workspace, device_workspace, launch_config) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) - -# -# Verify results -# -err, = cuda.cuStreamSynchronize(stream) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - -# -# Debug reporting of byte array contents -# - -def PrintBytearray(host_workspace): - uint_str = None - prefix = None - print("uint32_t host_workspace[] = {") - for idx, byte in enumerate(host_workspace): - if not (idx % 4): - if uint_str is not None: - print(prefix, uint_str, ",") - prefix = "/* offset: %d B */ 0x" % idx - uint_str = "" - uint_str = "{:02x}".format(byte) + uint_str - print("};") diff --git a/examples/41_multi_head_attention/CMakeLists.txt b/examples/41_multi_head_attention/CMakeLists.txt new file mode 100644 index 00000000..442048f6 --- /dev/null +++ b/examples/41_multi_head_attention/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+cutlass_example_add_executable(
+  41_multi_head_attention
+  fused_multihead_attention.cu
+  )
+
diff --git a/examples/41_multi_head_attention/fused_multihead_attention.cu b/examples/41_multi_head_attention/fused_multihead_attention.cu
new file mode 100644
index 00000000..455e6284
--- /dev/null
+++ b/examples/41_multi_head_attention/fused_multihead_attention.cu
@@ -0,0 +1,1145 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief CUTLASS Attention Example.
+ + This workload computes an attention example with non-fixed sequence length input. Pointers of arrays + are fed into grouped-GEMM functions fused with softmax for computation. + + Examples: + + # Run an attention example with default setup (max sequence length = 1024, batch size = 16, head size = 64, head number = 12) + $ ./examples/41_multi_head_attention/41_multi_head_attention + + # Run an attention example with batch size = 64 and head number = 16 without checking the correctness + $ ./examples/41_multi_head_attention/41_multi_head_attention --head_number=16 --batch_size=64 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/fast_math.h" +#include "gemm_attention.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool use_mask; + + std::vector problem_sizes0; + std::vector problem_sizes1; + + std::vector problem_sizes0_real; + std::vector problem_sizes1_real; + + int alignment; + int head_number; + int batch_size; + int head_size; + int seq_length; + int iterations; + int cuda_streams; + + // alpha0, alpha1 and beta are fixed + // in this multi-head attention example + float alpha0; + float alpha1; + float beta; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(16), + reference_check(true), + head_number(12), + batch_size(16), + head_size(64), + seq_length(1024), + use_mask(false), + iterations(20), + cuda_streams(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if 
(cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 16); + cmd.get_cmd_line_argument("head_number", head_number, 12); + cmd.get_cmd_line_argument("batch_size", batch_size, 16); + cmd.get_cmd_line_argument("head_size", head_size, 64); + cmd.get_cmd_line_argument("seq_length", seq_length, 1024); + cmd.get_cmd_line_argument("use_mask", use_mask, false); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + + randomize_problems(); + + } + + void randomize_problems() { + + int problem_count = head_number * batch_size; + + problem_sizes0.reserve(problem_count); + problem_sizes1.reserve(problem_count); + + // When using mask, the original inputs are not padded + // and we need to save these info. + if (use_mask) { + problem_sizes0_real.reserve(problem_count); + problem_sizes1_real.reserve(problem_count); + } + + for (int i = 0; i < batch_size; ++i) { + // problems belonging to the same batch share the same seq len + int m_real = (rand() % seq_length); + int m = (m_real + 1 + alignment - 1) / alignment * alignment; + int n = m; + int k = head_size; + + for (int j = 0; j < head_number; ++j) { + cutlass::gemm::GemmCoord problem0(m, n, k); + cutlass::gemm::GemmCoord problem1(m, k, n); + problem_sizes0.push_back(problem0); + problem_sizes1.push_back(problem1); + + if (use_mask) { + cutlass::gemm::GemmCoord problem0_real(m_real, m_real, k); + cutlass::gemm::GemmCoord problem1_real(m_real, k, m_real); + problem_sizes0_real.push_back(problem0_real); + problem_sizes1_real.push_back(problem1_real); + } + + } + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "41_multi_head_attention\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" + << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" + << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" + << " --seq_length= Max sequence length in multi-head attention (default: --seq_length=1024)\n" + << " --use_mask= If true, performs padding-like masking in softmax.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes0) { + // Two flops per multiply-add + fmas += problem.product() * 2; + } + + // Multiply another '2' because of the back-to-back GEMM problems in attention + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TestbedAttention { +public: + + // + // Type definitions + // + + using ElementQ = typename Attention::ElementQ; + using ElementK = typename Attention::ElementK; + using ElementP = typename Attention::ElementP; + using ElementAccumulator = typename Attention::GemmGrouped0::ElementAccumulator; + using ElementV = typename Attention::ElementV; + using ElementO = typename Attention::ElementOutput; + + using EpilogueOutputOp = typename 
Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ElementwiseFunctor; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using ElementNorm = typename Attention::ElementNorm; + using ElementSum = typename Attention::ElementSum; + using ElementSoftmaxCompute = typename Attention::ElementSoftmaxCompute; + + using LayoutQ = typename Attention::LayoutQ; + using LayoutK = typename Attention::LayoutK; + using LayoutP = typename Attention::LayoutP; + using LayoutV = typename Attention::LayoutV; + using LayoutO = typename Attention::LayoutO; + + using MatrixCoord = typename LayoutP::TensorCoord; + + using ProblemVisitor0 = typename Attention::GemmKernel0::ProblemVisitor; + using ProblemVisitor1 = typename Attention::GemmKernel1::ProblemVisitor; + +private: + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_Q; + cutlass::Distribution::Kind init_K; + cutlass::Distribution::Kind init_P; + cutlass::Distribution::Kind init_V; + cutlass::Distribution::Kind init_O; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device0; + cutlass::DeviceAllocation problem_sizes_device1; + cutlass::DeviceAllocation problem_sizes_device0_real; + + std::vector offset_Q; + std::vector offset_K; + std::vector offset_P; + std::vector offset_V; + std::vector offset_O; + std::vector offset_Norm; + std::vector offset_Sum; + + std::vector ldq_host; + std::vector ldk_host; + std::vector ldp_host; + std::vector ldv_host; + std::vector ldo_host; + std::vector seqlen_host; + + cutlass::DeviceAllocation ldq; + cutlass::DeviceAllocation ldk; + cutlass::DeviceAllocation ldp; + cutlass::DeviceAllocation ldv; + cutlass::DeviceAllocation ldo; + cutlass::DeviceAllocation seqlen; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_P; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_Norm; + cutlass::DeviceAllocation block_Sum; + + cutlass::DeviceAllocation offset_P_Device; + cutlass::DeviceAllocation offset_Norm_Device; + cutlass::DeviceAllocation offset_Sum_Device; + + cutlass::DeviceAllocation ptr_Q; + cutlass::DeviceAllocation ptr_K; + cutlass::DeviceAllocation ptr_P; + cutlass::DeviceAllocation ptr_V; + cutlass::DeviceAllocation ptr_O; + cutlass::DeviceAllocation ptr_Max; + cutlass::DeviceAllocation ptr_Sum; + +public: + + // + // Methods + // + + TestbedAttention( + Options &options_, + cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } + + int problem_count() const { + return (options.head_number * options.batch_size); + } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 
8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 8; + scope_min = -8; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_() { + + // + // Set scalors for the mha example + // + + options.alpha0 = 1.0f / sqrt(float(options.head_size)); + options.alpha1 = 1.0f; + options.beta = 0; + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_Q = 0; + int64_t total_elements_K = 0; + int64_t total_elements_P = 0; + int64_t total_elements_V = 0; + int64_t total_elements_O = 0; + + int64_t total_elements_partial_norm = 0; + + ldq_host.resize(problem_count()); + ldk_host.resize(problem_count()); + ldp_host.resize(problem_count()); + ldv_host.resize(problem_count()); + ldo_host.resize(problem_count()); + seqlen_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes0.at(i); + + ldq_host.at(i) = LayoutQ::packed({problem.m(), problem.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({problem.k(), problem.n()}).stride(0); + ldp_host.at(i) = LayoutP::packed({problem.m(), problem.n()}).stride(0); + ldv_host.at(i) = LayoutV::packed({problem.n(), problem.k()}).stride(0); + ldo_host.at(i) = LayoutO::packed({problem.m(), problem.k()}).stride(0); + + // m = n for attention problems. 
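As a quick illustration of the problem-size setup above (the values below are hypothetical, not taken from the example's RNG): randomize_problems() rounds each drawn sequence length up to the alignment, and initialize_() sets the usual 1/sqrt(d_k) softmax scale. A minimal standalone check of that arithmetic:

#include <cassert>
#include <cmath>

int main() {
  int const alignment = 16, head_size = 64;            // example defaults
  int m_real = 1000;                                   // hypothetical unpadded sequence length
  int m = (m_real + 1 + alignment - 1) / alignment * alignment;
  assert(m == 1008);                                   // smallest multiple of 16 >= m_real + 1
  float alpha0 = 1.0f / std::sqrt(float(head_size));   // 1/sqrt(d_k) scale applied to Q*K^T
  assert(alpha0 == 0.125f);
  return 0;
}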
+ int64_t non_leading_dim = ldp_host.at(i); + int64_t threadblock_n = Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape::kN; + int64_t threadblock_num = (ldp_host.at(i) + threadblock_n - 1) / threadblock_n; + + seqlen_host.at(i) = problem.m(); + + offset_Q.push_back(total_elements_Q); + offset_K.push_back(total_elements_K); + offset_P.push_back(total_elements_P); + offset_V.push_back(total_elements_V); + offset_O.push_back(total_elements_O); + offset_Norm.push_back(total_elements_partial_norm); + offset_Sum.push_back(total_elements_partial_norm); + + int64_t elements_Q = problem.m() * problem.k(); + int64_t elements_K = problem.k() * problem.n(); + int64_t elements_P = problem.m() * problem.n(); + int64_t elements_V = problem.n() * problem.k(); + int64_t elements_O = problem.m() * problem.k(); + int64_t elements_norm = non_leading_dim * threadblock_num; + + total_elements_Q += elements_Q; + total_elements_K += elements_K; + total_elements_P += elements_P; + total_elements_V += elements_V; + total_elements_O += elements_O; + total_elements_partial_norm += elements_norm; + + } + + problem_sizes_device0.reset(problem_count()); + problem_sizes_device1.reset(problem_count()); + problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); + problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); + + if (options.use_mask) { + problem_sizes_device0_real.reset(problem_count()); + problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); + } + + ldq.reset(problem_count()); + ldk.reset(problem_count()); + ldp.reset(problem_count()); + ldv.reset(problem_count()); + ldo.reset(problem_count()); + seqlen.reset(problem_count()); + + ldq.copy_from_host(ldq_host.data()); + ldk.copy_from_host(ldk_host.data()); + ldp.copy_from_host(ldp_host.data()); + ldv.copy_from_host(ldv_host.data()); + ldo.copy_from_host(ldo_host.data()); + seqlen.copy_from_host(seqlen_host.data()); + + // + // Assign pointers + // + + block_Q.reset(total_elements_Q); + block_K.reset(total_elements_K); + block_P.reset(total_elements_P); + block_V.reset(total_elements_V); + block_O.reset(total_elements_O); + block_Norm.reset(total_elements_partial_norm); + block_Sum.reset(total_elements_partial_norm); + + offset_P_Device.reset(problem_count()); + offset_Norm_Device.reset(problem_count()); + offset_Sum_Device.reset(problem_count()); + + // sync offset with device + cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); + cutlass::device_memory::copy_to_device(offset_Norm_Device.get(), offset_Norm.data(), offset_Norm.size()); + cutlass::device_memory::copy_to_device(offset_Sum_Device.get(), offset_Sum.data(), offset_Sum.size()); + + std::vector ptr_Q_host(problem_count()); + std::vector ptr_K_host(problem_count()); + std::vector ptr_P_host(problem_count()); + std::vector ptr_V_host(problem_count()); + std::vector ptr_O_host(problem_count()); + std::vector ptr_norm_host(problem_count()); + std::vector ptr_sum_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); + ptr_K_host.at(i) = block_K.get() + offset_K.at(i); + ptr_P_host.at(i) = block_P.get() + offset_P.at(i); + ptr_V_host.at(i) = block_V.get() + offset_V.at(i); + ptr_O_host.at(i) = block_O.get() + offset_O.at(i); + ptr_norm_host.at(i) = block_Norm.get() + offset_Norm.at(i); + ptr_sum_host.at(i) = block_Sum.get() + offset_Sum.at(i); + } + + ptr_Q.reset(problem_count()); + 
ptr_Q.copy_from_host(ptr_Q_host.data()); + + ptr_K.reset(problem_count()); + ptr_K.copy_from_host(ptr_K_host.data()); + + ptr_P.reset(problem_count()); + ptr_P.copy_from_host(ptr_P_host.data()); + + ptr_V.reset(problem_count()); + ptr_V.copy_from_host(ptr_V_host.data()); + + ptr_O.reset(problem_count()); + ptr_O.copy_from_host(ptr_O_host.data()); + + ptr_Max.reset(problem_count()); + ptr_Max.copy_from_host(ptr_norm_host.data()); + + ptr_Sum.reset(problem_count()); + ptr_Sum.copy_from_host(ptr_sum_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); + initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); + initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); + + } + + template + bool verify_tensor_(std::vector vector_Input, \ + std::vector vector_Input_Ref, + int64_t verify_length = -1) { + + int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); + size = (verify_length == -1) ? size : verify_length; + + // 0.05 for absolute error + float abs_tol = 5e-2f; + // 10% for relative error + float rel_tol = 1e-1f; + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); + float relative_diff = abs_diff / abs_ref; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("diff = %f, rel_diff = %f, {%f, %f}.\n", abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the result is a GEMM + bool verify_() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes0.at(i); + cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); + + LayoutQ layout_Q(ldq_host.at(i)); + LayoutK layout_K(ldk_host.at(i)); + LayoutP layout_P(ldp_host.at(i)); + LayoutV layout_V(ldv_host.at(i)); + LayoutO layout_O(ldo_host.at(i)); + + MatrixCoord extent_Q{problem.m(), problem.k()}; + MatrixCoord extent_K{problem.k(), problem.n()}; + MatrixCoord extent_P{problem.m(), problem.n()}; + MatrixCoord extent_V{problem.n(), problem.k()}; + MatrixCoord extent_O{problem.m(), problem.k()}; + + cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); + cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); + cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); + cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); + + cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); + + cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); + cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementQ, LayoutQ, + ElementK, LayoutK, + ElementP, LayoutP, + ElementCompute, ElementAccumulator + >( + problem, + ElementAccumulator(options.alpha0), + view_Q, + Attention::GemmGrouped0::kTransformA, + view_K, + Attention::GemmGrouped0::kTransformB, + ElementAccumulator(options.beta), + view_P, + view_Ref_device, + ElementAccumulator(0) + ); + + // Compute softmax for P. 
We need to explicitly compute softmax + // over P because softmax is fused to the second GEMM in the + // profiled implementation. + std::vector matrix_Ref(layout_P.capacity(extent_P)); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); + cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); + std::vector vector_Norm_Ref(problem.m()); + std::vector vector_Sum_Ref(problem.m()); + + int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem.n(); + + // Compute softmax for referece matrix + // Assumed a row-major storage + for (int m = 0; m < problem.m(); m++) { + ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); + for (int n = 1; n < n_dim; n++) { + max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); + } + + vector_Norm_Ref.at(m) = ElementNorm(max); + + ElementSoftmaxCompute sum = ElementSoftmaxCompute(); + for (int n = 0; n < n_dim; n++) { + sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); + } + ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); + + vector_Sum_Ref.at(m) = ElementSum(inv_sum); + + for (int n = 0; n < n_dim; n++) { + view_Ref_host.ref().at({m, n}) = ElementP( + std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum + ); + } + + } + + // when not using mask, problem_real and problem share the same sizes + if (options.use_mask) { + for (int m = 0; m < problem.m(); m++) { + for (int n = n_dim; n < problem.n(); n++) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + } + } + + cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementP, LayoutP, + ElementV, LayoutV, + ElementO, LayoutO, + ElementCompute, ElementAccumulator + >( + problem1, + ElementAccumulator(options.alpha1), + view_P, + Attention::GemmGrouped0::kTransformA, + view_V, + Attention::GemmGrouped0::kTransformB, + ElementAccumulator(options.beta), + view_Ref_O_device, + view_Ref_O_device, + ElementAccumulator(0) + ); + + // Copy to host memory + + int64_t threadblock_n = Attention::GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape::kN; + int64_t threadblock_num = (problem.m() + threadblock_n - 1) / threadblock_n; + + std::vector vector_Norm(problem.m() * threadblock_num); + std::vector vector_Sum(problem.m() * threadblock_num); + + cutlass::device_memory::copy_to_host(vector_Norm.data(), block_Norm.get() + offset_Norm.at(i), vector_Norm.size()); + cutlass::device_memory::copy_to_host(vector_Sum.data(), block_Sum.get() + offset_Sum.at(i), vector_Sum.size()); + + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); + + std::vector matrix_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); + std::vector matrix_Ref_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); + + bool verified_N = false; + bool verified_S = false; + bool verified_O = false; + + if (!verified_N) { + verified_N = verify_tensor_(vector_Norm, vector_Norm_Ref); + } + + if (!verified_S) { + verified_S = verify_tensor_(vector_Sum, vector_Sum_Ref); + } + + + if (!verified_O) { + verified_O = verify_tensor_(matrix_O, matrix_Ref_O); + } + + passed = passed && verified_N && verified_S && verified_O; + + if (!passed) { + std::cerr 
<< "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + + if (!verified_O) { + std::cout << "Final matrix output is incorrect" << std::endl; + } + + if (!verified_N) { + std::cout << "Max is incorrect" << std::endl; + } + + if (!verified_S) { + std::cout << "Sum is incorrect" << std::endl; + } + + return passed; + } + + } + + return passed; + } + +public: + + /// Returns the number of threadblocks to launch if the kernel can run on the target + /// device. Otherwise, returns zero. + int sufficient() const { + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + int occupancy = Attention::GemmGrouped0::maximum_active_blocks(); + + return properties.multiProcessorCount * occupancy; + + } + + + /// Executes a CUTLASS Attention kernel and measures runtime. + Result profile_grouped() { + + Result result; + + int threadblock_count = sufficient(); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Attention kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + initialize_(); + + typename Attention::Arguments args( + problem_sizes_device0.get(), + problem_sizes_device1.get(), + problem_count(), + threadblock_count, + ptr_Q.get(), + ptr_K.get(), + ptr_P.get(), + ptr_V.get(), + ptr_O.get(), + ptr_Max.get(), + ptr_Sum.get(), + block_P.get(), + block_Norm.get(), + block_Sum.get(), + offset_P_Device.get(), + offset_Norm_Device.get(), + offset_Sum_Device.get(), + ldq.get(), + ldk.get(), + ldp.get(), + ldv.get(), + ldo.get(), + ElementAccumulator(options.alpha0), + ElementAccumulator(options.alpha1), + ElementAccumulator(options.beta), + options.head_number, + options.batch_size, + options.seq_length, + options.problem_sizes0.data(), + options.problem_sizes1.data(), + problem_sizes_device0_real.get() + ); + + size_t workspace_size0 = ProblemVisitor0::kRequiresPrecomputation ?\ + ProblemVisitor0::get_workspace_size(options.problem_sizes0.data(),\ + problem_count(),\ + threadblock_count)\ + : 0; + + size_t workspace_size1 = ProblemVisitor1::kRequiresPrecomputation ?\ + ProblemVisitor1::get_workspace_size(options.problem_sizes1.data(),\ + problem_count(),\ + threadblock_count)\ + : 0; + + cutlass::DeviceAllocation workspace0(workspace_size0); + cutlass::DeviceAllocation workspace1(workspace_size1); + + Attention attention; + + result.status = attention.initialize(args, workspace0.get(), workspace1.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Attention kernel." << std::endl; + return result; + } + + result.status = attention.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Attention kernel." 
<< std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_(); + } + + // + // Warm-up run of the grouped GEMM object + // + + result.status = attention.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Attention kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + attention(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << "CUTLASS Attention:\n" + << "====================================================" << std::endl; + std::cout << " " << " {max sequence length, head size, head number, batch size} = {" << options.seq_length \ + << ", " << options.head_size << ", " << options.head_number << ", " << options.batch_size << "}." << std::endl; + std::cout << std::endl; + std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "GFLOPs: " << result.gflops << std::endl; + + return result; + } + + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. 
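To put the GFLOPs figure reported by profile_grouped() in perspective: Options::gflops() counts two flops per multiply-add and doubles once more for the two back-to-back GEMMs. A rough standalone calculation, assuming (purely for illustration) that every problem uses the full default sequence length of 1024:

#include <cassert>
#include <cstdint>

int main() {
  int64_t const m = 1024, n = 1024, k = 64;   // one head's Q*K^T problem at full length
  int64_t const problems = 12 * 16;           // head_number * batch_size defaults
  int64_t fmas = m * n * k * 2 * problems;    // two flops per multiply-add, summed over problems
  assert(fmas == 25769803776LL);
  double gflop = 2.0 * double(fmas) / 1.0e9;  // doubled again for the chained GEMMs
  assert(gflop > 51.0 && gflop < 52.0);       // roughly 51.5 GFLOP of work per attention pass
  return 0;
}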
+ // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Define the CUTLASS Attention type + // + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using ElementQ = cutlass::half_t; + using ElementK = cutlass::half_t; + using ElementP = ElementOutput; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutP = cutlass::layout::RowMajor; + + static bool const UseMask = false; + + if (UseMask != options.use_mask) { + std::cerr << "UseMask and user-defined use_mask need to be consistant, " + << " aborted execution.\n"; + return -2; + } + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<32, 32, 32>; + + static int const Stages0 = 3; + static int const Stages1 = 4; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Attention = cutlass::FusedMultiHeadAttention< + ElementQ, + LayoutQ, + ElementK, + LayoutK, + ElementP, + LayoutP, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + Stages0, + Stages1, + UseMask + >; + + // + // Test and profile + // + + TestbedAttention testbed(options); + + if (!testbed.sufficient()) { + std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; + return 0; + } + + Result result = testbed.profile_grouped(); + if (!result.passed) { + std::cout << "Profiling CUTLASS attention has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_multi_head_attention/gemm_attention.h b/examples/41_multi_head_attention/gemm_attention.h new file mode 100644 index 00000000..9990c0fb --- /dev/null +++ b/examples/41_multi_head_attention/gemm_attention.h @@ -0,0 +1,626 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines the FusedMultiHeadAttention Class + + The class contains the following: + 1) GEMM0 with epilogue fusion, + 2) GEMM1 with mainloop fusion, and + 3) A lightweight full softmax reduction kernel. + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h" +#include "cutlass/reduction/kernel/reduce_softmax_final.h" +#include "gemm_grouped_with_softmax_visitor.h" + +namespace cutlass { + +template < + typename ElementQ_, + typename LayoutQ_, + typename ElementK_, + typename LayoutK_, + typename ElementP_, + typename LayoutP_, + typename ElementCompute_, + typename OperatorClass_, + typename ArchTag_, + typename ThreadblockShape0_, + typename ThreadblockShape1_, + typename WarpShape0_, + typename WarpShape1_, + typename InstructionShape_, + int kStages0_, + int kStages1_, + bool UseMasking_ = false, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute, + int Alignment = 128 / cutlass::sizeof_bits::value, + typename ElementSoftmax_ = ElementP_ +> +class FusedMultiHeadAttention { +public: + + using ElementQ = ElementQ_; + using ElementK = ElementK_; + using ElementP = ElementP_; + using ElementV = ElementK; + using ElementOutput = ElementP; + using ElementAccumulator = ElementCompute_; + + using LayoutQ = LayoutQ_; + using LayoutK = LayoutK_; + using LayoutP = LayoutP_; + using LayoutV = LayoutK; + using LayoutO = LayoutP; + + using ElementNorm = cutlass::half_t; + using ElementSum = cutlass::half_t; + using ElementSoftmaxCompute = float; + using LayoutNorm = cutlass::layout::RowMajor; + + using ThreadblockSwizzle = 
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + + using ThreadblockShape0 = ThreadblockShape0_; + using WarpShape0 = WarpShape0_; + + using ThreadblockShape1 = ThreadblockShape1_; + using WarpShape1 = WarpShape1_; + + static int const Stages0 = kStages0_; + static int const Stages1 = kStages1_; + + using InstructionShape = InstructionShape_; + + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; + + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>; + + using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementQ, ElementK, ElementP, + ElementAccumulator>::Operator; + static bool const kInternalTranspose = cutlass::platform::is_same::value; + + static bool const kUseMasking = UseMasking_; + + static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_; + static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_; + + using MapArguments = cutlass::gemm::kernel::detail::MapArguments< + ElementQ, + LayoutQ, + cutlass::ComplexTransform::kNone, + 8, + ElementK, + LayoutK, + cutlass::ComplexTransform::kNone, + 8, + LayoutP, + kInternalTranspose + >; + + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementP, + typename MapArguments::LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape0, + WarpShape0, + InstructionShape, + EpilogueOutputOp0, + ThreadblockSwizzle, + Stages0, + true, + Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< + ThreadblockShape0, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + typename EpilogueOutputOp0::ElementCompute, + ElementNorm, + ElementSum, + ElementSoftmaxCompute, + EpilogueOutputOp0, + kUseMasking + >; + + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor< + typename DefaultGemmKernel::Mma, + Epilogue, + ThreadblockSwizzle, + kGroupScheduleMode0, + kInternalTranspose, + kUseMasking + >; + + using GemmGrouped0 = cutlass::gemm::device::GemmGrouped; + + using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< + ElementNorm, + ElementSum, + typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute, + typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape, + true + >; + + using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion< + ElementP, + LayoutP, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementV, + LayoutV, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementNorm, 
+ LayoutNorm, + ElementOutput, + LayoutO, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape1, + WarpShape1, + InstructionShape, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages1, + kGroupScheduleMode1 + >::GemmKernel; + + using GemmGrouped1 = cutlass::gemm::device::GemmGrouped; + +public: + + /// Arguments class + struct Arguments { + cutlass::gemm::GemmCoord *problem_sizes0; + cutlass::gemm::GemmCoord *problem_sizes0_real; + cutlass::gemm::GemmCoord *problem_sizes1; + int problem_count; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementP ** ptr_V; + ElementP ** ptr_O; + + ElementNorm **ptr_Max; + ElementSum **ptr_Sum; + + ElementP *block_P; + ElementNorm *block_Norm; + ElementSum *block_Sum; + int64_t *offset_P; + int64_t *offset_Norm_Device; + int64_t *offset_Sum_Device; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldp; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutP::Stride::LongIndex *ldo; + + cutlass::gemm::GemmCoord *problem_sizes0_host; + cutlass::gemm::GemmCoord *problem_sizes1_host; + + ElementAccumulator alpha0; + ElementAccumulator alpha1; + ElementAccumulator beta; + + int head_number; + int batch_size; + int seq_length; + + typename ApplyFinalReductionDevice::Arguments reduction; + + // + // Methods + // + Arguments(): + problem_count(0), + threadblock_count(0), + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + block_P(nullptr), + block_Norm(nullptr), + block_Sum(nullptr), + offset_P(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + ldq(nullptr), + ldk(nullptr), + ldp(nullptr), + ldv(nullptr), + ldo(nullptr), + head_number(0), + batch_size(0), + seq_length(0) + { + + } + + Arguments( + cutlass::gemm::GemmCoord *problem_sizes0, + cutlass::gemm::GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + ElementQ ** ptr_Q, + ElementK ** ptr_K, + ElementP ** ptr_P, + ElementP ** ptr_V, + ElementP ** ptr_O, + ElementNorm **ptr_Max, + ElementSum **ptr_Sum, + ElementP *block_P, + ElementNorm *block_Norm, + ElementSum *block_Sum, + int64_t *offset_P, + int64_t *offset_Norm_Device, + int64_t *offset_Sum_Device, + typename LayoutQ::Stride::LongIndex *ldq, + typename LayoutK::Stride::LongIndex *ldk, + typename LayoutP::Stride::LongIndex *ldp, + typename LayoutP::Stride::LongIndex *ldv, + typename LayoutP::Stride::LongIndex *ldo, + ElementAccumulator alpha0, + ElementAccumulator alpha1, + ElementAccumulator beta, + int head_number, + int batch_size, + int seq_length, + cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr, + cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr, + cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_P(ptr_P), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_Max(ptr_Max), + ptr_Sum(ptr_Sum), + block_P(block_P), + block_Norm(block_Norm), + block_Sum(block_Sum), + offset_P(offset_P), + offset_Norm_Device(offset_Norm_Device), + offset_Sum_Device(offset_Sum_Device), + ldq(ldq), + ldk(ldk), + ldp(ldp), + ldv(ldv), + ldo(ldo), + alpha0(alpha0), + alpha1(alpha1), + beta(beta), + head_number(head_number), + batch_size(batch_size), + seq_length(seq_length), + problem_sizes0_host(problem_sizes0_host), + 
problem_sizes1_host(problem_sizes1_host), + problem_sizes0_real(problem_sizes0_real), + reduction( + problem_sizes0, + block_Norm, + block_Sum, + offset_Norm_Device, + offset_Sum_Device + ) + { + + } + + + }; + + struct Params { + cutlass::gemm::GemmCoord *problem_sizes0; + cutlass::gemm::GemmCoord *problem_sizes0_real; + cutlass::gemm::GemmCoord *problem_sizes1; + int problem_count; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementP ** ptr_V; + ElementP ** ptr_O; + + ElementNorm **ptr_Max; + ElementSum **ptr_Sum; + + ElementP *block_P; + ElementNorm *block_Norm; + ElementSum *block_Sum; + int64_t *offset_P; + int64_t *offset_Norm_Device; + int64_t *offset_Sum_Device; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldp; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutP::Stride::LongIndex *ldo; + + cutlass::gemm::GemmCoord *problem_sizes0_host; + cutlass::gemm::GemmCoord *problem_sizes1_host; + + ElementAccumulator alpha0; + ElementAccumulator alpha1; + ElementAccumulator beta; + + int head_number; + int batch_size; + int seq_length; + + typename ApplyFinalReductionDevice::Params reduction; + + Params(): + problem_count(0), + threadblock_count(0), + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + block_P(nullptr), + block_Norm(nullptr), + block_Sum(nullptr), + offset_P(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + ldq(nullptr), + ldk(nullptr), + ldp(nullptr), + ldv(nullptr), + ldo(nullptr), + problem_sizes0(nullptr), + problem_sizes1(nullptr), + problem_sizes0_real(nullptr), + head_number(0), + batch_size(0), + seq_length(0) + { + + } + + Params(Arguments const &args, void *workspace = nullptr): + problem_sizes0(args.problem_sizes0), + problem_sizes1(args.problem_sizes1), + problem_count(args.problem_count), + threadblock_count(args.threadblock_count), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), + block_P(args.block_P), + block_Norm(args.block_Norm), + block_Sum(args.block_Sum), + offset_P(args.offset_P), + offset_Norm_Device(args.offset_Norm_Device), + offset_Sum_Device(args.offset_Sum_Device), + ldq(args.ldq), + ldk(args.ldk), + ldp(args.ldp), + ldv(args.ldv), + ldo(args.ldo), + problem_sizes0_host(args.problem_sizes0_host), + problem_sizes1_host(args.problem_sizes1_host), + problem_sizes0_real(args.problem_sizes0_real), + alpha0(args.alpha0), + alpha1(args.alpha1), + beta(args.beta), + head_number(args.head_number), + batch_size(args.batch_size), + seq_length(args.seq_length), + reduction(args.reduction) + { + + } + }; + + +private: + + Params params_; + GemmGrouped0 gemm_grouped0; + GemmGrouped1 gemm_grouped1; + + +public: + + /// Ctor + FusedMultiHeadAttention() { + + } + + /// Initialize + Status initialize(Arguments const &args, + void *workspace0 = nullptr, + void *workspace1 = nullptr) { + + params_ = Params(args); + + typename GemmGrouped0::Arguments args_gemm0( + params_.problem_sizes0, + params_.problem_count, + params_.threadblock_count, + params_.ptr_Q, + params_.ptr_K, + params_.ptr_P, + params_.ptr_P, + params_.ptr_Max, + params_.ptr_Sum, + params_.ldq, + params_.ldk, + params_.ldp, + params_.ldp, + typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments( + { + params_.alpha0, + params_.beta + } + ), + 
params_.problem_sizes0_host, + params_.problem_sizes0_real + ); + + + Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0); + + typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta); + + typename GemmGrouped1::Arguments args_gemm1( + params_.problem_sizes1, + params_.problem_count, + params_.threadblock_count, + epilogue_op1, + params_.ptr_P, + params_.ptr_V, + params_.ptr_O, + params_.ptr_O, + (void**)params_.ptr_Max, + (void**)params_.ptr_Sum, + params_.ldp, + params_.ldv, + params_.ldo, + params_.ldo, + params_.problem_sizes1_host + ); + + Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1); + + if ((result0 == cutlass::Status::kSuccess) && (result1 == cutlass::Status::kSuccess) ) { + return cutlass::Status::kSuccess; + }else{ + if (result0 != cutlass::Status::kSuccess) { + return result0; + }else{ + return result1; + } + } + } + + /// Run + Status run(cudaStream_t stream = nullptr) { + + Status result = gemm_grouped0.run(); + cudaError_t error_info; + + if (result != cutlass::Status::kSuccess) { + return cutlass::Status::kErrorInternal; + } + + int thread_per_block = 1024; + + dim3 final_reduction_grid(params_.head_number * params_.batch_size); + dim3 final_reduction_block(thread_per_block); + + cutlass::Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream + >>>(params_.reduction); + + error_info = cudaGetLastError(); + + if (error_info != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = gemm_grouped1.run(); + + if (result != cutlass::Status::kSuccess) { + return cutlass::Status::kErrorInternal; + } + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +} diff --git a/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h b/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h new file mode 100644 index 00000000..755c1252 --- /dev/null +++ b/examples/41_multi_head_attention/gemm_grouped_with_softmax_visitor.h @@ -0,0 +1,522 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped GEMM kernel with epilogue visitor customized for softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform + bool Transposed_ = false, + bool UseMask_ = false +> +struct GemmGroupedWithEpilogueVistor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + using EpilogueVisitor = typename Epilogue::Visitor; + using EpilogueOutputOp = typename EpilogueVisitor::ElementwiseFunctor; + static bool const kTransposed = Transposed_; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments< + typename Mma::IteratorA::Element, + typename Mma::IteratorA::Layout, + Mma::kTransformA, + Mma::IteratorA::AccessType::kElements, + typename Mma::IteratorB::Element, + typename Mma::IteratorB::Layout, + Mma::kTransformB, + Mma::IteratorB::AccessType::kElements, + typename Mma::LayoutC, + kTransposed + >; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename MapArguments::LayoutC; + + using ElementNorm = typename EpilogueVisitor::ElementNorm; + using ElementSum = typename EpilogueVisitor::ElementSum; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
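When kTransposed is set, this kernel solves the transposed problem and swaps the A/B pointers and strides (see the operand exchange in operator() further below); that relies on the identity (A*B)^T = B^T * A^T. A small host-side check of the identity, illustrative only and not part of the patch:

#include <array>
#include <cassert>

using Mat2 = std::array<std::array<int, 2>, 2>;

Mat2 mul(Mat2 const &x, Mat2 const &y) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) r[i][j] += x[i][k] * y[k][j];
  return r;
}

Mat2 transpose(Mat2 const &x) {
  Mat2 r{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) r[i][j] = x[j][i];
  return r;
}

int main() {
  Mat2 a{{{1, 2}, {3, 4}}};
  Mat2 b{{{5, 6}, {7, 8}}};
  // Swapping the operands and transposing each one reproduces the transposed product.
  assert(transpose(mul(a, b)) == mul(transpose(b), transpose(a)));
  return 0;
}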
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount, + kTransposed>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes; + // when using mask, real problem sizes may not be aligned + // then we need to mask out unpadded elements in softmax + GemmCoord *problem_sizes_real; + int problem_count; + int threadblock_count; + + ElementA ** ptr_A; + ElementB ** ptr_B; + ElementC ** ptr_C; + ElementC ** ptr_D; + + ElementNorm **ptr_Max; + ElementSum **ptr_Sum; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): + problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr) + { + + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes, + int problem_count, + int threadblock_count, + ElementA ** ptr_A, + ElementB ** ptr_B, + ElementC ** ptr_C, + ElementC ** ptr_D, + ElementNorm **ptr_Max, + ElementSum **ptr_Sum, + typename LayoutA::Stride::LongIndex *lda, + typename LayoutB::Stride::LongIndex *ldb, + typename LayoutC::Stride::LongIndex *ldc, + typename LayoutC::Stride::LongIndex *ldd, + typename EpilogueVisitor::Arguments epilogue_visitor_, + GemmCoord *host_problem_sizes=nullptr, + GemmCoord *problem_sizes_real=nullptr + ): + problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + ptr_Max(ptr_Max), + ptr_Sum(ptr_Sum), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + epilogue_visitor(epilogue_visitor_), + host_problem_sizes(host_problem_sizes), + problem_sizes_real(problem_sizes_real) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + GemmCoord *problem_sizes_real; + int threadblock_count; + + ElementA ** ptr_A; + ElementB ** ptr_B; + ElementC ** ptr_C; + ElementC ** ptr_D; + + ElementNorm **ptr_Max; + ElementSum **ptr_Sum; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + typename EpilogueVisitor::Params epilogue_visitor; 
+ + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + problem_sizes_real(problem_sizes_real) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, void *workspace = nullptr, int32_t tile_count = 0): + problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd), + epilogue_visitor(args.epilogue_visitor), + problem_sizes_real(args.problem_sizes_real) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int32_t tile_count = -1) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_Max = args.ptr_Max; + ptr_Sum = args.ptr_Sum; + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + problem_sizes_real = args.problem_sizes_real; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmGroupedWithEpilogueVistor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Mma::LayoutC; + + // + // Problem visitor. + // + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. 
Exchange pointers and strides if working on the transpose + ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_offset.n() + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k(), problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + ElementC *ptr_C = params.ptr_C[problem_idx]; + ElementC *ptr_D = params.ptr_D[problem_idx]; + + ElementNorm *ptr_Max = params.ptr_Max[problem_idx]; + ElementSum *ptr_Sum = params.ptr_Sum[problem_idx]; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + int column_offset = (threadblock_offset.n() / ThreadblockShape::kN) * problem_size.m(); + + typename EpilogueVisitor::OutputTileIterator::Params params_C(layout_C); + typename EpilogueVisitor::OutputTileIterator::Params params_D(layout_D); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.kernel.epilogue.visitor, + problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params_C, + params_D, + ptr_C, + ptr_D, + ptr_Max, + ptr_Sum, + threadblock_offset.mn(), + column_offset, + params.problem_sizes_real[problem_idx].mn() + ); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.kernel.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor + epilogue(epilogue_visitor, accumulators); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d5fdac14..02ae033a 100644 --- a/examples/CMakeLists.txt +++ 
b/examples/CMakeLists.txt @@ -116,6 +116,10 @@ foreach(EXAMPLE 34_transposed_conv2d 35_gemm_softmax 36_gather_scatter_fusion + 37_gemm_layernorm_gemm_fusion + 38_syr2k_grouped + 39_gemm_permute + 41_multi_head_attention ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 48a499c5..4b0ed6ad 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -98,6 +98,9 @@ template < bool IsHermitianData = false> struct cp_async_diag; +static const uint32_t OOB_NAN_F16 = 0x7eff; +static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization @@ -190,8 +193,8 @@ struct cp_async_nan<16, CacheOperation::Always> { cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { #if CUDA_CP_ASYNC_ACTIVATED - static __constant__ uint4 OOB_NAN_F16x8 = {0x7eff7eff, 0x7eff7eff, - 0x7eff7eff, 0x7eff7eff}; + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); @@ -305,7 +308,6 @@ struct cp_async_diag { } }; - //////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization @@ -386,6 +388,47 @@ struct cp_async_zfill { } }; +/// Partial specialization +template <> +struct cp_async_nan<16, CacheOperation::Global> { + static int const kSizeInBytes = 16; + + /// Copy with nan fill + CUTLASS_DEVICE + cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +#endif + " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" + "}\n" + : + : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), + "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), + "r"(OOB_NAN_F16x8.w)); + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_UNUSED(global_ptr); + CUTLASS_UNUSED(pred_guard); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// /// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. 
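Editor's note on the constants introduced above: 0x7eff decodes to a binary16 NaN (all five exponent bits set, non-zero mantissa), and OOB_NAN_F16x2 simply packs two copies into one 32-bit word so the uint4 used by cp_async_nan spells the same pattern without repeated magic numbers. The host-side sketch below, which assumes nothing beyond standard headers and is not part of the patch, checks that packing.

#include <cstdint>
#include <cstdio>

int main() {
  uint32_t const OOB_NAN_F16   = 0x7eff;                            // one half-precision NaN bit pattern
  uint32_t const OOB_NAN_F16x2 = (OOB_NAN_F16 << 16) | OOB_NAN_F16; // two halves per 32-bit word

  // binary16 is NaN when the 5 exponent bits are all ones and the 10 mantissa bits are not all zero.
  bool exponent_all_ones = (OOB_NAN_F16 & 0x7c00u) == 0x7c00u;
  bool mantissa_nonzero  = (OOB_NAN_F16 & 0x03ffu) != 0u;
  std::printf("0x%04x is %sa half-precision NaN\n", (unsigned)OOB_NAN_F16,
              (exponent_all_ones && mantissa_nonzero) ? "" : "NOT ");

  // A 16-byte cp.async transaction fills eight halves: four copies of the x2 word,
  // which is exactly what the uint4 OOB_NAN_F16x8 initializer in the hunk above spells out.
  uint32_t oob_nan_f16x8[4] = {OOB_NAN_F16x2, OOB_NAN_F16x2, OOB_NAN_F16x2, OOB_NAN_F16x2};
  std::printf("16-byte fill word: 0x%08x repeated 4 times\n", (unsigned)oob_nan_f16x8[0]);
  return 0;
}
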
diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 816072bd..22b633fb 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -126,6 +126,10 @@ struct Mma< : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -188,6 +192,10 @@ struct Mma< ); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -251,6 +259,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -308,6 +320,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -366,6 +382,10 @@ struct Mma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -423,6 +443,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -486,6 +510,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -543,6 +571,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -600,6 +632,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -657,6 +693,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -711,7 +751,6 @@ struct Mma< unsigned const & A = reinterpret_cast(a); unsigned const & B = reinterpret_cast(b); - int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); @@ -720,6 +759,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -777,6 +820,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -834,6 +881,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -891,6 +942,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -954,6 +1009,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1011,6 +1070,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1068,6 +1131,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1125,6 +1192,10 @@ struct Mma< : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); #else + CUTLASS_UNUSED(a); + 
CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1210,11 +1281,19 @@ struct Mma< nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index 5b9f5240..36005174 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -587,6 +587,10 @@ struct Mma< "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1571,6 +1575,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1631,6 +1639,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1691,6 +1703,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1751,6 +1767,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1818,6 +1838,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1878,6 +1902,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1938,6 +1966,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1998,6 +2030,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -2059,6 +2095,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -2126,6 +2166,10 @@ struct Mma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED) diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h index e22d6006..4df64bf3 100644 --- a/include/cutlass/arch/mma_sparse_sm80.h +++ b/include/cutlass/arch/mma_sparse_sm80.h @@ -141,6 +141,10 @@ struct SparseMma< assert(0); } #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -224,6 +228,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -296,6 +304,10 @@ struct SparseMma, 32, bfloat16_t, layout::RowMajor, #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -368,6 +380,10 @@ struct SparseMma, 32, tfloat32_t, layout::RowMajor, #else + CUTLASS_UNUSED(a); + 
CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -449,6 +465,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -524,6 +544,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -599,6 +623,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -674,6 +702,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -755,6 +787,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -830,6 +866,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -905,6 +945,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -980,6 +1024,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1061,6 +1109,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1136,6 +1188,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1211,6 +1267,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1286,6 +1346,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1367,6 +1431,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1442,6 +1510,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1517,6 +1589,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -1592,6 +1668,10 @@ struct SparseMma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index af0b271a..10fe46a4 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -127,8 +127,8 @@ template class complex { public: - /// Type alias for scalar type - using value_type = T; + /// Type alias for scalar type + using value_type = T; private: // diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index 7d0c86f4..d33de182 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -268,7 +268,7 @@ public: CUTLASS_HOST_DEVICE int64_t filter_size() const { - return (K * R * S * C); + return (K * R * S * C / groups); } /// Returns output size in number of elements @@ -362,61 +362,128 @@ int implicit_gemm_k_iterations( Operator conv_operator, int threadblock_K, Conv2dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int 
threadblock_N = 0) { int iterations = 0; - if (algorithm == IteratorAlgorithm::kFixedChannels) { + if (group_mode == GroupMode::kNone) { - int positions_per_iteration = threadblock_K / problem_size.C; - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; - break; + if (algorithm == IteratorAlgorithm::kFixedChannels) { - default: - break; + int positions_per_iteration = threadblock_K / problem_size.C; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; + break; + + default: + break; + } } - } - else if (algorithm == IteratorAlgorithm::kFewChannels) { + else if (algorithm == IteratorAlgorithm::kFewChannels) { - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; - break; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; + break; - default: - break; + default: + break; + } } - } - else { - int elements_per_split_k_slice = 0; + else { + int elements_per_split_k_slice = 0; - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; - default: - break; + default: + break; + } } + + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } else { // Group conv + + int channels_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch 
(conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups + if (problem_size.groups != 1) { + if (k_per_group < threadblock_N) { + iterations *= threadblock_N / k_per_group; + } + } + break; + + default: + break; + } + } + } return iterations; } +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations_per_channel( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + + int iterations = 0; //0 means not applicable + if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S; + break; + + case Operator::kDgrad: + iterations = problem_size.R * problem_size.S; + break; + + default: + break; + } + } + return iterations; +} + //////////////////////////////////////////////////////////////////////////////// // Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) //////////////////////////////////////////////////////////////////////////////// @@ -537,12 +604,12 @@ void strided_dgrad_starting_coords( // function locals for remainder by fast divmod int pad_h_rem_, pad_w_rem_; - // start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; + // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); stride_h_divmod.divmod(start_h, r_); - //start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; + //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); stride_w_divmod.divmod(start_w, s_); diff --git a/include/cutlass/conv/conv3d_problem_size.h b/include/cutlass/conv/conv3d_problem_size.h index 82ea1cef..da59a3a9 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -339,29 +339,46 @@ int implicit_gemm_k_iterations( Operator conv_operator, int threadblock_K, Conv3dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { int iterations = 0; int elements_per_split_k_slice = 0; + if (group_mode == GroupMode::kNone) { + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + 
elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; - - default: - break; + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.T * problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } } return iterations; diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 52a4636c..372a60b9 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -117,6 +117,14 @@ enum class SplitKMode { kParallel }; +/// Identifies group mode +enum class GroupMode { + kNone, + kSingleGroup, ///< One CTA calculates one group or less + kMultipleGroup, ///< One CTA calculates multiple groups + kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace conv diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 8e87ec56..bac90f15 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -78,6 +78,7 @@ public: static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator; static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm; static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = ImplicitGemmKernel::kGroupMode; static int const kWarpCount = (ThreadblockShape::kM / WarpShape::kM) * @@ -111,6 +112,34 @@ public: return status; } + // check group conv constraint + if (args.problem_size.groups != 1) { + if (kGroupMode == conv::GroupMode::kNone) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K % args.problem_size.groups || + args.problem_size.C % args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + // split-k is not supported + if (args.problem_size.split_k_slices != 1) { + return Status::kErrorInvalidProblem; + } + + int k_per_group = 
args.problem_size.K / args.problem_size.groups; + // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group + if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { + return Status::kErrorInvalidProblem; + } + // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups + if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { + return Status::kErrorInvalidProblem; + } + } + static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess; if (kConvolutionalOperator == conv::Operator::kFprop) { if (args.problem_size.K % kAlignmentC) diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 4fc2200a..89a02886 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -45,8 +45,8 @@ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,7 +161,7 @@ struct DefaultConv2dFpropFusion < LayoutScaleBias>; using SmemIteratorScaleBias = - cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, LayoutScaleBias>; @@ -172,7 +172,7 @@ struct DefaultConv2dFpropFusion < static int const kThreadCount = 32; // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias< + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< MatrixShape, ElementScaleBias, LayoutScaleBias, MatrixShape, typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, @@ -296,7 +296,7 @@ struct DefaultConv2dFpropFusion < LayoutScaleBias>; using SmemIteratorScaleBias = - cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, LayoutScaleBias>; @@ -307,7 +307,7 @@ struct DefaultConv2dFpropFusion < static int const kThreadCount = 32; // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias< + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< MatrixShape, ElementScaleBias, LayoutScaleBias, MatrixShape, typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h new file mode 100644 index 00000000..b17e7c59 --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 
2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFpro +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dGroupFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. 
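As context for the grouped specialization declared in the template that follows, the grouped branch added to implicit_gemm_k_iterations() earlier in this diff decides how many mainloop iterations each CTA runs. The standalone paraphrase below, with a worked example, is an editor's reading aid; the function and variable names are local to the sketch and are not CUTLASS API.

#include <cstdio>

int group_fprop_k_iterations(int C, int K, int groups, int R, int S,
                             int threadblock_K, int threadblock_N) {
  int channels_per_group = C / groups;
  int k_per_group        = K / groups;
  // Sweep the per-group input channels for every filter position (analytic iterator, kFprop).
  int iterations = R * S * ((channels_per_group + threadblock_K - 1) / threadblock_K);
  // If one CTA tile in N spans several groups (k_per_group < threadblock_N), the CTA must
  // also sweep each of those groups' channels in turn.
  if (groups != 1 && k_per_group < threadblock_N) {
    iterations *= threadblock_N / k_per_group;
  }
  return iterations;
}

int main() {
  // Example: C = K = 128, groups = 4, 3x3 filter, ThreadblockShape::kK = 32, kN = 128.
  // channels_per_group = 32 gives 3 * 3 * 1 = 9 iterations; k_per_group = 32 < 128, so the
  // CTA covers 128 / 32 = 4 groups, for 36 iterations in total.
  std::printf("%d\n", group_fprop_k_iterations(128, 128, 4, 3, 3, 32, 128));
  return 0;
}
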
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h new file mode 100644 index 00000000..80b220f4 --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultConv3dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage +/// pipeline. 
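The specialization that follows wires the scale/bias iterators into the mainloop. Functionally, per the file brief above, each activation element is transformed with a per-channel scale and bias followed by ReLU before it feeds the implicit-GEMM multiply-accumulate. The scalar reference below is an editor's sketch of that transform only; the real kernel applies it tile-wise in registers through WarpIteratorScaleBias rather than through a helper like this.

#include <algorithm>

// Reference-only: the fused transform for one activation element x taken from
// channel c of the input tensor, given folded per-channel scale[c] and bias[c]
// (for example, a folded batch-norm apply).
template <typename T>
T fused_scale_bias_relu(T x, T scale_c, T bias_c) {
  return std::max(scale_c * x + bias_c, T(0));
}
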
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for 
Conv3dFprop specialzation for Optimzed IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + 
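Stepping back from the fusion specializations, the grouped-convolution support threaded through this diff also tightens argument checking: the additions to device::ImplicitGemmConvolution::can_implement earlier in the patch reject group problems the kernel cannot tile. The sketch below restates those checks as a free function for readability; the enum and parameter names are local stand-ins for the cutlass::conv types, not the library API.

// Mirrors the group-conv constraints added to can_implement() in
// include/cutlass/conv/device/implicit_gemm_convolution.h in this diff.
enum class GroupMode { kNone, kSingleGroup, kMultipleGroup, kDepthwise };

bool group_conv_is_implementable(int C, int K, int groups, int split_k_slices,
                                 GroupMode mode, int threadblock_N) {
  if (groups == 1) return true;                       // ordinary convolution: nothing extra to check
  if (mode == GroupMode::kNone) return false;         // kernel was not built with group support
  if (K % groups || C % groups) return false;         // C and K must be multiples of groups
  if (split_k_slices != 1) return false;              // split-K is not supported with groups
  int k_per_group = K / groups;
  // kSingleGroup: each CTA stays inside one group, so a group's filters must fill
  // a whole number of CTA tiles along N.
  if (mode == GroupMode::kSingleGroup && (k_per_group % threadblock_N)) return false;
  // kMultipleGroup: one CTA spans several groups, so the CTA tile along N must hold
  // a whole number of groups' filters.
  if (mode == GroupMode::kMultipleGroup && (threadblock_N % k_per_group)) return false;
  return true;
}
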
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_depthwise_fprop.h b/include/cutlass/conv/kernel/default_depthwise_fprop.h new file mode 100644 index 00000000..b4005a4a --- /dev/null +++ b/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value +> struct DefaultDepthwiseFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Depthwise specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, // cutlass::arch::OpMultiplyAdd + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< + ThreadblockShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + sizeof_bits::value, + 2, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using 
ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + cutlass::conv::GroupMode::kDepthwise + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index d3f1a19f..df75fcfb 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -62,7 +62,8 @@ template < typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode > struct ImplicitGemmConvolution { @@ -117,6 +118,8 @@ struct ImplicitGemmConvolution { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = GroupMode_; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -198,6 +201,7 @@ struct ImplicitGemmConvolution { int swizzle_log_tile; int gemm_k_iterations; + int gemm_k_iterations_per_channel; typename Mma::IteratorA::Params iterator_A; typename Mma::IteratorA::Element const *ptr_A; typename Mma::IteratorB::Params iterator_B; @@ -241,7 +245,12 @@ struct ImplicitGemmConvolution { kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, - kIteratorAlgorithm); + kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, kIteratorAlgorithm); ThreadblockSwizzle threadblock_swizzle; @@ -286,6 +295,17 @@ struct ImplicitGemmConvolution { // Compute position within threadblock int thread_idx = threadIdx.x; + int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode != GroupMode::kDepthwise) { + int k_per_group = params.problem_size.K / params.problem_size.groups; + int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; + int channels_per_group = params.problem_size.C / params.problem_size.groups; + iterator_A_column_offset += group_idx * channels_per_group; + } else { + iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( @@ -295,7 +315,7 @@ struct ImplicitGemmConvolution { thread_idx, MatrixCoord( threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK + iterator_A_column_offset ) ); @@ -327,7 +347,7 @@ struct ImplicitGemmConvolution { accumulators.clear(); // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); // // Epilogue diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h index d43521f1..f99f9a6c 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -119,6 +119,8 @@ struct ImplicitGemmConvolutionFusion { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 65191f5a..949dbdf5 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionStridedDgrad { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static 
conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -488,4 +490,3 @@ struct ImplicitGemmConvolutionStridedDgrad { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index 2ab47637..3b5f3731 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 4014173b..9c9e8a71 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -248,7 +248,7 @@ public: pointer_ += pointer_offset * sizeof_bits::value / 8; } - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE void advance() { int next_idx = 0; @@ -263,18 +263,33 @@ public: // Move filter_r by stride_h filter_r_ += problem_size_.stride_h; - +#if 0 bool check = (filter_r_ < problem_size_.R); filter_r_ = check ? filter_r_ : start_r_; next_idx = check ? 1 : 2; reset_bytes += (check ? 
reset_bytes_s_ : reset_bytes_r_); +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .s64 t1;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 t1, %6, %7, %%p;\n\t" + " add.s64 %2, %8, t1;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); +#endif } // offset pointers by offset_bytes pointer_ += (params_.inc_next[next_idx] - reset_bytes); - if (next_idx == 2) { + if (next_idx == 2) { filter_k_ += params_.filter_k_delta; } diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index 80448f36..5576c818 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -528,7 +528,6 @@ public: int k = filter_k_ + iteration_vector_ * AccessType::kElements; return TensorCoord(n, p, q, k); - } /// Returns true if the current coordinate is within the output tensor Dy diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 4b1e906a..a825f4ce 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -321,7 +321,7 @@ public: add_byte_offset_(pointer_offset * sizeof_bits::value / 8); } - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE void advance() { int next_idx = 0; @@ -336,8 +336,9 @@ public: // Move filter_r by stride_h filter_r_ += problem_size_.stride_h; +#if 0 if (filter_r_ < problem_size_.R) { - + next_idx = 1; // Restore bytes in q coordinate (Mma in filter s dimenstion) @@ -347,12 +348,25 @@ public: // Restore filter_r filter_r_ = start_r_; - + next_idx = 2; - + // Restore bytes in p and q coordinate (Mma in filter s and r dimenstion) reset_bytes = reset_bytes_r_; } +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 %2, %6, %7, %%p;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_)); +#endif } // offset pointers by offset_bytes diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index add089af..a9ebaebd 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -67,7 +67,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone > class Conv2dFpropActivationTileAccessIteratorAnalytic { public: @@ -89,6 +90,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 
2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; @@ -119,6 +121,11 @@ private: int filter_c_; int filter_r_; int filter_s_; + int filter_c_init_; + int group_idx_offset_; + int channels_per_group_; + int crs_cnt_; + int crs_per_group_; int offset_n_[ThreadMap::Iterations::kStrided]; int offset_p_[ThreadMap::Iterations::kStrided]; @@ -137,6 +144,8 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_(0), filter_c_(0), filter_r_(0), filter_s_(0) { @@ -145,6 +154,12 @@ public: filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + channels_per_group_ = problem_size_.C / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); + } + CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; @@ -182,6 +197,10 @@ public: CUTLASS_HOST_DEVICE void advance() { // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + ++filter_s_; if (filter_s_ < problem_size_.S) { return; @@ -192,8 +211,19 @@ public: return; } filter_r_ = 0; - - filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + // moves to next group + crs_cnt_ = 0; + ++group_idx_offset_; + filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; + } else { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } + } } /// Returns the coordinate in the activations tensor X that is currently pointed to diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 08d3176d..626dc800 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -66,7 +66,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone > class Conv2dFpropFilterTileAccessIteratorAnalytic { public: @@ -88,6 +89,7 @@ public: static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; @@ -118,8 +120,14 @@ private: int filter_r_; int filter_s_; int filter_c_; + int filter_c_init_; + int crs_cnt_; + int crs_per_group_; + int group_idx_offset_c_; + int channels_per_group_; int offset_k_[ThreadMap::Iterations::kStrided]; + int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; public: @@ -134,6 +142,8 @@ public: params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_c_(0), filter_r_(0), filter_s_(0), filter_c_(0) { @@ 
-142,9 +152,23 @@ public: filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + if (kGroupMode == conv::GroupMode::kDepthwise){ + channels_per_group_ = 1; + crs_per_group_ = problem_size_.S * problem_size_.R; + } else { + channels_per_group_ = problem_size_.C / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); + } + } + CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); + } } set_iteration_index(0); @@ -168,6 +192,10 @@ public: CUTLASS_HOST_DEVICE void advance() { // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + ++filter_s_; if (filter_s_ < problem_size_.S) { return; @@ -179,8 +207,21 @@ public: return; } filter_r_ = 0; - - filter_c_ += Shape::kRow * problem_size_.split_k_slices; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + crs_cnt_ = 0; + filter_c_ = filter_c_init_; + if (kGroupMode != conv::GroupMode::kDepthwise) { + // moves to next group + ++group_idx_offset_c_; + } + } else { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + } } /// Returns the coordinate in the filter tensor W that is currently pointed to @@ -200,8 +241,14 @@ public: TensorCoord coord = at(); - return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + if (kGroupMode == conv::GroupMode::kNone) { + return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + } else if (kGroupMode == conv::GroupMode::kDepthwise) { + return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. 
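+        // Illustrative note (added commentary, not part of the original source): for
+        // grouped Fprop the constructor's bookkeeping works out as follows for a
+        // hypothetical problem with C = 64, K = 32, groups = 4 and Shape::kRow = 16:
+        //   channels_per_group_    = C / groups = 16            // channel slice owned by one group
+        //   crs_per_group_         = S * R * ceil(16 / 16) = S * R
+        //   group_idx_offset_k_[s] = k / (K / groups) = k / 8    // group that output channel k maps to
+        // The grouped branch below then accepts a (k, c) pair only when c lies inside
+        // the group's channel slice and k belongs to that same group.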
+ } else { + return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && + group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; + } } /// Returns a pointer to the vector starting at the current coordinate diff --git a/include/cutlass/conv/threadblock/conv2d_params.h b/include/cutlass/conv/threadblock/conv2d_params.h index 1ba9532c..0d8fc83e 100644 --- a/include/cutlass/conv/threadblock/conv2d_params.h +++ b/include/cutlass/conv/threadblock/conv2d_params.h @@ -554,20 +554,20 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams { // next S inc_next[0] = conv_sign * ( - layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next R inc_next[1] = conv_sign * ( - layout.stride()[1] * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next K inc_next[2] = ( threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - conv_sign * (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h + - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // logical offset added to internal channel counter - units are elements, not bytes @@ -614,12 +614,12 @@ struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { // next S inc_next[0] = conv_sign * ( - layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next R inc_next[1] = conv_sign * ( - layout.stride()[1] * problem_size.dilation_h + (int64_t)layout.stride()[1] * problem_size.dilation_h ) * element_size_bits / 8; // next K @@ -670,18 +670,18 @@ struct Conv2dDgradFilterIteratorOptimizedParams { TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; + inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; inc_next_rs = - ( layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] ) * element_size_bits / 8; inc_next_k = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2] - - (problem_size.R * problem_size.S - 1) * layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] + - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] ) * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -730,26 +730,26 @@ struct Conv2dStridedDgradFilterIteratorOptimizedParams { // next S inc_next[0] = - ( layout.stride()[0] * problem_size.stride_w + ( (int64_t)layout.stride()[0] * problem_size.stride_w 
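+          // Promoting the first operand to int64_t keeps the whole byte-increment
+          // product in 64-bit arithmetic, so a large stride multiplied by the conv
+          // stride and element_size_bits / 8 cannot overflow a 32-bit intermediate
+          // before the result is used as a pointer offset.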
//- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // next R inc_next[1] = - ( layout.stride()[1] * problem_size.stride_h + ( (int64_t)layout.stride()[1] * problem_size.stride_h //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // next K inc_next[2] = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // offset in units of bytes to move the pointer in backward direction - reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -800,13 +800,13 @@ struct Conv2dWgradOutputGradientIteratorOptimizedParams { element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 - offset_next_strided = (threadmap_delta.strided() * layout.stride()[0]) + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) * element_size_bits / 8; offset_next_contiguous = (threadmap_delta.contiguous()) * element_size_bits / 8; - inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0]) + inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) * element_size_bits / 8; } }; @@ -891,4 +891,3 @@ struct PredicatedScaleBiasVectorAccessIteratorParams { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 66dd75d2..595497cf 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -104,6 +104,11 @@ public: return TileAccessIterator::getParams(problem_size, layout); } + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + tile_access_iterator_.set_iteration_index(index); + } /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE diff --git a/include/cutlass/conv/threadblock/conv3d_params.h b/include/cutlass/conv/threadblock/conv3d_params.h index 5ad1e4fa..4ba2960f 100644 --- a/include/cutlass/conv/threadblock/conv3d_params.h +++ b/include/cutlass/conv/threadblock/conv3d_params.h @@ -304,8 +304,8 @@ struct Conv3dDgradOutputGradientIteratorOptimizedParams { // logical offset added to internal channel counter - units are elements, not bytes filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; } - }; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Parameters object for Conv2d DGRAD Filter (w) iterator @@ -343,18 +343,18 @@ struct Conv3dDgradFilterIteratorOptimizedParams { TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, 
threadmap_delta); - inc_next_strided = (layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; + inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; inc_next_trs = - ( layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] ) * element_size_bits / 8; inc_next_k = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[3] - - (problem_size.T * problem_size.R * problem_size.S - 1) * layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] + - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] ) * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -408,13 +408,13 @@ struct Conv3dWgradOutputGradientIteratorOptimizedParams { element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 - offset_next_strided = (threadmap_delta.strided() * layout.stride()[0]) + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) * element_size_bits / 8; offset_next_contiguous = (threadmap_delta.contiguous()) * element_size_bits / 8; - inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0]) + inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) * element_size_bits / 8; // Precompute several quantities for fast modulo arithmetic. diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h new file mode 100644 index 00000000..6ae75fd9 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + 
/// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + 
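+    // Expository sketch of the control flow (added commentary, not original code):
+    // the prologue below does one global -> register -> shared-memory round trip for
+    // each operand and preloads the first pair of warp fragments, as in the standard
+    // double-buffered pipeline. The depthwise-specific bookkeeping lives in the
+    // mainloop further down: rs_plane_idx (its name suggests an index over the R*S
+    // filter plane) counts iterations within one channel tile, and each time it
+    // reaches gemm_k_iterations_per_channel - 1 the B iterator's internal iteration
+    // index is reset to 0 and channel_start_index advances by kWarpGemmIterations,
+    // steering set_kgroup_index() so the warp-level iterators move on to the next
+    // slice of channels.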
iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + // Depthwise specific + int channel_start_index = 0; + int rs_plane_idx = 0; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Reset interation index. + iterator_B.set_iteration_index(0); + } + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Move to next set of filter groups. + channel_start_index += Base::kWarpGemmIterations; + } + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + + rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 
0: (rs_plane_idx + 1); + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h new file mode 100644 index 00000000..f13bdc32 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -0,0 +1,337 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data + layout of the global memory fragments, data types, and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. 
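+
+    The kLaneAccessSizeA/B template parameters below are interpreted as a per-thread
+    access width in bits: they are divided by the element's sizeof_bits to obtain the
+    LaneM/LaneN of the warp-level LaneMmaShape. A worked example with values chosen
+    purely for illustration: ElementB = half_t and kLaneAccessSizeB = 64 give
+    numElementsB = 64 / 16 = 4, so LaneN = min(4, ThreadTileN). Passing -1 for both
+    sizes selects the partial specialization that simply falls back to
+    gemm::threadblock::DefaultMmaCore.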
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_singlestage.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/conv/warp/mma_depthwise_simt.h" + +#include "cutlass/arch/cache_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA, + /// per-element transformation for elements of B + ComplexTransform TransformB, + bool IsComplex +> +struct DepthwiseMmaCoreWithLaneAccessSize< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> : cutlass::gemm::threadblock::DefaultMmaCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeA_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeB_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + kLaneAccessSizeB_, + 2, + Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_> { + using Base = cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_>; + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeA = kLaneAccessSizeA_; + static int const kLaneAccessSizeB = kLaneAccessSizeB_; + + // Divisility requirements + static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = typename Base::WarpCount; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
+ ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory are same as base class + // + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); + static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1; + static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + WarpCount::kK + >; +}; + +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h index 4f16b42d..d47d9a29 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h @@ -64,7 +64,7 @@ #include "cutlass/arch/cache_operation.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" #include "cutlass/conv/warp/scale_bias_relu_transform.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -139,6 +139,13 @@ class 
MmaFpropFusionBase { /// Tensor reference to the B operand using TensorRefB = TensorRef; + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + // // Nested structs // @@ -319,7 +326,7 @@ class ImplicitGemmFpropFusionMultistage using Policy = Policy_; ///< Base class using Base = MmaFpropFusionBase; using SmemIteratorA = SmemIteratorA_; @@ -518,6 +525,8 @@ public: IteratorScaleBias iterator_A_scale_bias, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index 36b41aac..9be1be77 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -116,10 +116,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -272,6 +268,8 @@ public: IteratorB iterator_B, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { @@ -297,7 +295,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; @@ -322,7 +320,7 @@ public: this->smem_iterator_B_.get()); CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / diff --git a/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h index f77e2e33..aade7ac4 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h @@ -188,6 +188,7 @@ public: IteratorA iterator_A, ///< iterator over A operand in global memory IteratorB iterator_B, ///< iterator over B operand in global memory FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel TransformA transform_A = TransformA(), ///< transformation applied to A fragment TransformB transform_B = TransformB()) { ///< transformation applied to B fragment diff --git a/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h index 7066997a..f4bb0ee7 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_wgrad_fusion_multistage.h @@ -70,7 +70,7 @@ #include 
"cutlass/arch/cache_operation.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" #include "cutlass/conv/warp/scale_bias_relu_transform.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -138,6 +138,13 @@ class MmaWgradFusionBase { /// Tensor reference to the B operand using TensorRefB = TensorRef; + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + // // Nested structs // @@ -306,10 +313,6 @@ class ImplicitGemmWgradFusionMultistage /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -470,6 +473,8 @@ public: IteratorScaleBias iterator_B_scale_bias, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { diff --git a/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h index 7d60e4b0..a0fb8104 100644 --- a/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h +++ b/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -113,12 +113,9 @@ class PredicatedScaleBiasVectorAccessIterator( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.T * problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { pointer_ = (thread_id < kThreads) ? 
reinterpret_cast( const_cast(scale_pointer)) @@ -177,6 +207,22 @@ class PredicatedScaleBiasVectorAccessIterator + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseSimt + : public cutlass::gemm::warp:: + MmaSimt { + using Base = cutlass::gemm::warp:: + MmaSimt; + +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; // < 64, 16 , 8> + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + +public: + + /// Iterates over the B operand in memory + using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseSimt():Base() {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h new file mode 100644 index 00000000..9fec53f1 --- /dev/null +++ b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT + instructions +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions +/// +/// concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1 +> +class DepthwiseMmaSimtTileIterator; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization for B operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseMmaSimtTileIterator + : public cutlass::gemm::warp::MmaSimtTileIterator { + + using Base = cutlass::gemm::warp::MmaSimtTileIterator; + public: + /// Shape of 
tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename Base::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Thread-level shape of a fragment + using ThreadShape = typename Base::ThreadShape; + + /// Number of individual loads + using Iterations = typename Base::Iterations; + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + + static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); + +private: + + MatrixCoord lane_offset_; + int channel_idx_; + int base_channel_idx_; + int warps_n_; + + public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator():Base() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator( + TensorRef ref, + int lane_id + ) : Base(ref, lane_id) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + warps_n_ = -1; + channel_idx_ = 0; + base_channel_idx_ = 0; + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + + if(warps_n_ == -1){ + warps_n_ = coord.column(); + } + + Base::add_tile_offset(coord); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < Iterations::kRow; ++k) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + + void const *ptr = this->ref_.data() + + this->ref_.offset({-(channel_idx_ - base_channel_idx_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + + // Base_k of a warp + Base_k of current threads. + int thread_k_base_idx = + warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); + + if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { + // Depthwise kernel would only do computation when channel == k. + // Loads an element when the current computation channel == the k corresponding to this thread. + arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); + } else { + // Reduce SMEM load + dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. 
Rather, it overrides its internal + /// tracking with constant-valued k-group index + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + if(k_group % PartitionGroupSize == 0 && k_group != 0){ + base_channel_idx_ = k_group; + } + channel_idx_ = k_group; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/conv/warp/scale_bias_relu_transform.h b/include/cutlass/conv/warp/scale_bias_relu_transform.h index 5bcbfcd0..2944c43c 100644 --- a/include/cutlass/conv/warp/scale_bias_relu_transform.h +++ b/include/cutlass/conv/warp/scale_bias_relu_transform.h @@ -101,7 +101,7 @@ struct FpropScaleBiasReluTransform { "}\n" : "=r"(ptr_activations[0]) : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), - "r"(ptr_scale_bias[1]), "n"(0x7eff7eff)); + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); #else // TODO: write emulation code assert(0); @@ -151,8 +151,8 @@ struct WgradScaleBiasReluTransform { #if 1 // CUDA + PTX version - bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == 0x7eff); - bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == 0x7eff); + bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); + bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); // Apply per channel scale+bias+relu if the data is not a special NaN // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. @@ -161,7 +161,7 @@ struct WgradScaleBiasReluTransform { // out-of-bound because C x R x S can be an odd number. asm volatile( "{\n\t" - " fma.rn.f16x2.relu %0 , %1, %2, %3;\n" + " fma.rn.f16x2.relu %0, %1, %2, %3;\n" "}" : "=r"(reinterpret_cast(ptr_activations[0])) : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), @@ -195,7 +195,7 @@ struct WgradScaleBiasReluTransform { "}\n" : "=r"(reinterpret_cast(ptr_activations[0])) : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), - "r"(ptr_scale_bias[1]), "n"(0x7eff), "n"(0xffff0000), "n"(0x0000ffff)); + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); #endif #else // TODO: write emulation code diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index ebc4c1e9..32eb6e00 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -43,7 +43,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0) +#define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) #if !defined(__CUDACC_RTC__) @@ -192,4 +192,3 @@ CUTLASS_HOST_DEVICE bool thread0() { } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 9763f5fc..6706af5b 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -98,6 +98,32 @@ struct ReLu> { } }; +// Leaky Relu operator +template +struct LeakyReLU { + CUTLASS_HOST_DEVICE + T operator()(T const &value, T const & alpha_recip) const { + T res = value > T(0) ? 
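// A minimal illustration of the scalar semantics (hypothetical values): with
// alpha_recip = 0.1f, LeakyReLU<float>{}(2.0f, 0.1f) returns 2.0f (positive inputs
// pass through) and LeakyReLU<float>{}(-2.0f, 0.1f) returns -0.2f (negative inputs
// are scaled by alpha_recip). The Array<T, N> specialization below applies the same
// scalar functor lane by lane.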
value : value * alpha_recip; + return res; + } +}; + +template +struct LeakyReLU > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs, T const & alpha_recip) const { + Array y; + LeakyReLU leaky_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(rhs.size()); ++i) { + y[i] = leaky_op(rhs[i], alpha_recip); + } + + return y; + } +}; + // Tanh operator template struct Tanh { @@ -135,32 +161,6 @@ struct Tanh> { } }; -// Leaky Relu operator -template -struct LeakyReLU { - CUTLASS_HOST_DEVICE - T operator()(T const &value, T const & alpha_recip) const { - T res = value > T(0) ? value : value * alpha_recip; - return res; - } -}; - -template -struct LeakyReLU > { - CUTLASS_HOST_DEVICE - Array operator()(Array const &rhs, T const & alpha_recip) const { - Array y; - LeakyReLU leaky_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(rhs.size()); ++i) { - y[i] = leaky_op(rhs[i], alpha_recip); - } - - return y; - } -}; - // Sigmoid operator template struct Sigmoid { diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h index d43ce5c4..9f184f85 100644 --- a/include/cutlass/epilogue/thread/linear_combination_generic.h +++ b/include/cutlass/epilogue/thread/linear_combination_generic.h @@ -157,7 +157,7 @@ public: if (k_partition) { beta_ = ElementCompute(1); } - + if (k_partition != k_partition_count - 1) { skip_elementwise_ = true; } diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index ce665251..cfaaa8b4 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -65,6 +65,8 @@ #include "cutlass/epilogue/threadblock/shared_load_iterator.h" #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/layout/permute.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -79,7 +81,8 @@ template < typename WarpMmaSimt_, typename OutputOp_, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueSimt { @@ -109,7 +112,8 @@ struct DefaultEpilogueSimt { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< @@ -310,7 +314,6 @@ struct DefaultEpilogueSimtAffineRankN { }; ///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index 46f23e1b..c232a2db 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -74,6 +74,8 @@ #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/interleaved_epilogue.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -166,7 +168,7 @@ template < typename ThreadMap > struct DefaultIteratorsTensorOp { - + using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOp< WarpShape, 
InstructionShape, @@ -265,7 +267,7 @@ struct DefaultIteratorsTensorOp< layout::RowMajor >; - using WarpTileIterator = typename cutlass::platform::conditional< + using WarpTileIterator = typename platform::conditional< (ThreadblockShape::kN == 256), WarpTileIteratorNotMixed, WarpTileIteratorMixed>::type; @@ -284,7 +286,7 @@ struct DefaultIteratorsTensorOp< int32_t >; - using SharedLoadIterator = typename cutlass::platform::conditional< + using SharedLoadIterator = typename platform::conditional< (ThreadblockShape::kN == 256), SharedLoadIteratorNotMixed, SharedLoadIteratorMixed>::type; @@ -302,7 +304,8 @@ template < int PartitionsK, typename OutputOp_, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueTensorOp { @@ -334,6 +337,7 @@ struct DefaultEpilogueTensorOp { OutputTileThreadMap, ElementOutput, ScatterD, + PermuteDLayout, UseCUDAStore >; @@ -570,7 +574,6 @@ struct DefaultEpilogueTensorOpAffineRankN { }; //////////////////////////////////////////////////////////////////////////////// - /// Defines sensible defaults for epilogues for TensorOps which uses /// intereleaved output layout. For this case, shared memory is not needed. template struct DefaultEpilogueVoltaTensorOp { @@ -111,7 +114,8 @@ struct DefaultEpilogueVoltaTensorOp { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorVoltaTensorOp< @@ -326,7 +330,6 @@ struct DefaultEpilogueVoltaTensorOpAffineRankN { }; ///////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace threadblock } // namespace epilogue } // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index ebac2c46..4cc8faf5 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -49,6 +49,8 @@ #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -67,7 +69,8 @@ template < typename ElementVector, typename OutputOp, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueWithBroadcastTensorOp { @@ -86,7 +89,8 @@ struct DefaultEpilogueWithBroadcastTensorOp { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< typename Base::OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; // diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h index 223d7a7c..918d3790 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_reduction.h @@ -50,6 +50,8 @@ #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/epilogue_with_reduction.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace 
cutlass { @@ -67,7 +69,8 @@ template < typename OutputOp, typename ReductionOp, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueWithReductionTensorOp { @@ -89,7 +92,8 @@ struct DefaultEpilogueWithReductionTensorOp { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< typename Base::OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; /// Define the epilogue @@ -120,7 +124,8 @@ template < typename OutputOp, typename ReductionOp, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueWithReductionVoltaTensorOp { @@ -142,7 +147,8 @@ struct DefaultEpilogueWithReductionVoltaTensorOp { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< typename Base::OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; /// Define the epilogue diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h index e35065e3..7acdbf68 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h @@ -64,6 +64,8 @@ #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -79,7 +81,8 @@ template < int PartitionsK, typename OutputOp_, int ElementsPerAccess, - bool ScatterD = false + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute > struct DefaultEpilogueWmmaTensorOp { @@ -109,7 +112,8 @@ struct DefaultEpilogueWmmaTensorOp { using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, - ScatterD + ScatterD, + PermuteDLayout >; using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorWmmaTensorOp< diff --git a/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h b/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h new file mode 100644 index 00000000..9e2ffd28 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h @@ -0,0 +1,513 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue. + + The epilogue finds max values in each row of the row-major output matrix and stores them. + The max values are also used for a further round of threadblock scoped reduction operation, where + the partial reduction results are stored in a pre-allocated array and used for further full reduction. + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename ElementAccumulator_, + typename ElementNorm_, + typename ElementSum_, + typename ElementSoftmaxCompute_, + typename ElementwiseFunctor_, + bool UseMasking_ = false +> +class EpilogueVisitorSoftmax { +public: + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using ElementNorm = ElementNorm_; + using ElementSum = ElementSum_; + using ElementSoftmaxCompute = ElementSoftmaxCompute_; + + using AccumulatorFragment = Array; + using SoftmaxFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + static bool const kUseMasking = UseMasking_; + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Max; + int64_t batch_stride_Sum; + + // + // Methods + // + Arguments(): + batch_stride_C(0), + batch_stride_D(0), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_ + ): + elementwise(elementwise_), + batch_stride_C(0), + batch_stride_D(0), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + } 
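// A minimal sketch of populating these Arguments for a batched problem; the stride
// values are placeholders whose exact meaning depends on how the caller lays out C/D
// and the per-row partial max/sum buffers:
//
//   Arguments args(
//     epilogue_params,     // typename ElementwiseFunctor::Params (alpha, beta, ...)
//     batch_stride_C,      // elements between consecutive batches of C
//     batch_stride_D,      // elements between consecutive batches of D
//     batch_stride_Max,    // elements between consecutive batches of the row-max buffer
//     batch_stride_Sum);   // elements between consecutive batches of the row-sum buffer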
+ + Arguments( + typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_C_, + int64_t batch_stride_D_, + int64_t batch_stride_Max_, + int64_t batch_stride_Sum_ + ): + elementwise(elementwise_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_), + batch_stride_Max(batch_stride_Max_), + batch_stride_Sum(batch_stride_Sum_) + { + + } + + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Max; + int64_t batch_stride_Sum; + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Max(args.batch_stride_Max), + batch_stride_Sum(args.batch_stride_Sum) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + + ElementNorm *ptr_Max_; + ElementSum *ptr_Sum_; + + int column_offset_; + + ElementSoftmaxCompute accum_max_; + ElementSoftmaxCompute accum_sum_; + + MatrixCoord thread_offset_; + + float infinity_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorSoftmax( + Params const ¶ms, + SharedStorage &shared_storage, + cutlass::MatrixCoord const &problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + typename OutputTileIterator::Element *ptr_C, + typename OutputTileIterator::Element *ptr_D, + ElementNorm *ptr_Max = nullptr, + ElementSum *ptr_Sum = nullptr, + cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0), + float infinity = 10000.0f + ): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + ptr_Max_(ptr_Max), + ptr_Sum_(ptr_Sum), + column_offset_(column_offset), + extent_real_(problem_size_real), + infinity_(infinity) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // Clear accumulators for max and sum when starting a whole row + clear_accum_(); + + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + Minus minus; + Exp exponential; + + SoftmaxFragment result; + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (kUseMasking) { + int elements_in_boundary = extent_real_.column() - thread_offset_.column(); + elements_in_boundary = (elements_in_boundary > kElementsPerAccess) ? kElementsPerAccess : elements_in_boundary; + elementwise_padding_(result, elements_in_boundary); + } + + ElementSoftmaxCompute accum_max_prev = accum_max_; + + // Compute the maximum within one row + if (!column_idx) { + // This is the first fragment in a new row + if (column_guard) { + accum_max_ = maximum_accumulator_(result); + } + } + else { + // This is an additional fragment in the same row + if (column_guard) { + accum_max_ = maximum_accumulator_(result, accum_max_); + } + } + + // proactively compute max in warps + accum_max_ = warp_reduce_max_(accum_max_); + + ElementSoftmaxCompute updater = fast_exp(accum_max_prev - accum_max_); + + SoftmaxFragment intermediate = exponential(minus(result, accum_max_)); + + if (kHasMultiStepsInRow) { + if (!column_idx) { + accum_sum_ = (column_guard) ? \ + sum_accumulator_(intermediate) : ElementSoftmaxCompute(0); + } else { + // Algorithm in $3.1, https://arxiv.org/pdf/2205.14135v1.pdf + // S* = S* x updater + sum_row(P'), where updater = exp(M* - M_row) + accum_sum_ = (column_guard) ? \ + sum_accumulator_(intermediate, accum_sum_ * updater) : accum_sum_ * updater; + } + } else { + accum_sum_ = (column_guard) ? 
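// The branch above follows the online softmax recurrence cited in its comment
// (FlashAttention, Sec. 3.1): with M* the previous running maximum of the row, M_row
// the new maximum, and P' = exp(result - M_row),
//
//   updater = exp(M* - M_row)
//   S*      = S* * updater + sum_row(P')
//
// i.e. earlier partial sums are rescaled whenever a later fragment raises the row max.
// A small numeric sketch (hypothetical row values): processing [1, 3] and then [5],
//   after [1, 3]: M* = 3, S* = exp(-2) + exp(0)
//   after [5]:    updater = exp(3 - 5), M* = 5,
//                 S* = (exp(-2) + exp(0)) * exp(-2) + exp(0)
//                    = exp(1 - 5) + exp(3 - 5) + exp(5 - 5),
// the exact softmax denominator for a maximum of 5. warp_reduce_max_() keeps M*
// consistent across the threads covering one row via an XOR-butterfly shuffle before
// the rescaling is applied.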
sum_accumulator_(intermediate, accum_sum_) : ElementSoftmaxCompute(0); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertSumOutput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + ConvertSumOutput convert_sum_output; + ConvertNormOutput convert_norm_output; + + // Compute accumulate sum only in the last step + accum_sum_ = warp_reduce_sum_(accum_sum_); + + bool is_first_thread_in_tile = ((threadIdx.x % kThreadsPerRow) == 0); + bool row_guard = thread_offset_.row() < extent_.row(); + bool is_write_thread = row_guard && is_first_thread_in_tile; + + int block_batch = blockIdx.z; + + ElementNorm *curr_ptr_max = ptr_Max_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Max; + ElementSum *curr_ptr_sum = ptr_Sum_ + thread_offset_.row() + column_offset_ + block_batch * params_.batch_stride_Sum; + + arch::global_store( + convert_norm_output(accum_max_), + (void *)curr_ptr_max, + is_write_thread); + + arch::global_store( + convert_sum_output(accum_sum_), + (void *)curr_ptr_sum, + is_write_thread); + + // Clear accumulators for max and sum when finishing a whole row + clear_accum_(); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + void elementwise_padding_(SoftmaxFragment &result, int elements_in_boundary) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SoftmaxFragment::kElements; ++i) { + result[i] = (i < elements_in_boundary) ? 
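// Padded lanes are filled with -infinity_ so that masking composes with the softmax
// math in visit(): a padded element can never become the row maximum, and
// exp(-infinity_ - max) underflows to (approximately) zero, so it adds nothing to the
// row sum. The default infinity_ of 10000.0f (see the constructor) is assumed to be
// large enough for the value range of ElementSoftmaxCompute.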
result[i] : ElementSoftmaxCompute(-infinity_); + } + } + + CUTLASS_DEVICE + ElementSoftmaxCompute warp_reduce_sum_(ElementSoftmaxCompute sum_) { + int half_thread_in_row = (kThreadsPerRow >> 1); + CUTLASS_PRAGMA_UNROLL + for (int i = half_thread_in_row; i > 0; i >>= 1) { + ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, sum_, i); + sum_ += tmp; + } + return sum_; + } + + CUTLASS_DEVICE + ElementSoftmaxCompute warp_reduce_max_(ElementSoftmaxCompute max_) { + int half_thread_in_row = (kThreadsPerRow >> 1); + CUTLASS_PRAGMA_UNROLL + for (int i = half_thread_in_row; i > 0; i >>= 1) { + ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, max_, i); + max_ = fast_max(max_, tmp); + } + return max_; + } + + CUTLASS_DEVICE + void clear_accum_() { + + uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX + float min_float = reinterpret_cast(float_max_bits); + accum_max_ = ElementSoftmaxCompute(min_float); + accum_sum_ = ElementSoftmaxCompute(0); + } + + CUTLASS_DEVICE + ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) { + ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SoftmaxFragment::kElements; ++i) { + sum_ += ElementSoftmaxCompute(accum[i]); + } + + return sum_; + } + + CUTLASS_DEVICE + ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute sum_) { + // ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SoftmaxFragment::kElements; ++i) { + sum_ += ElementSoftmaxCompute(accum[i]); + } + + return sum_; + } + + CUTLASS_DEVICE + ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) { + ElementSoftmaxCompute max_ = accum[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < SoftmaxFragment::kElements; ++i) { + max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); + } + + return max_; + } + + CUTLASS_DEVICE + ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < SoftmaxFragment::kElements; ++i) { + max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); + } + + return max_; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 943cd5d0..3fafcbc4 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -39,11 +39,12 @@ #pragma once -#include #if defined(__CUDACC_RTC__) #include +#include #else #include +#include #endif #include "cutlass/cutlass.h" diff --git a/examples/35_gemm_softmax/epilogue_with_visitor.h b/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h similarity index 99% rename from examples/35_gemm_softmax/epilogue_with_visitor.h rename to include/cutlass/epilogue/threadblock/epilogue_with_visitor.h index aa322553..21cb1ef9 100644 --- a/examples/35_gemm_softmax/epilogue_with_visitor.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_visitor.h @@ -121,6 +121,7 @@ public: /// Called after accumulators have been exchanged for each accumulator vector CUTLASS_DEVICE void visit( + int iter_idx, int row_idx, int column_idx, int frag_idx, @@ -128,7 +129,7 @@ public: } - /// Called at the start of a row + /// Called at the end of a row CUTLASS_DEVICE void end_row(int row_idx) { @@ -325,6 +326,7 @@ public: } visitor.visit( + iter_idx, 
row_idx, col_idx, idx, diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 83b07b99..c57abca4 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -391,10 +391,10 @@ struct OutputTileOptimalThreadMap { 1>; /// Initial offset function - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static MatrixCoord initial_offset(int thread_idx) { - int warp_idx = thread_idx / kWarpSize; + int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); int lane_idx = thread_idx % kWarpSize; // Compute warp location @@ -419,7 +419,7 @@ struct OutputTileOptimalThreadMap { return MatrixCoord( cluster_offset + group_offset + row_offset + lane_row_offset, - (column_offset + lane_col_offset) * kElementsPerAccess + column_offset + lane_col_offset * kElementsPerAccess ); } @@ -461,10 +461,10 @@ struct OutputTileOptimalThreadMap { static int const kThreads = Threads; /// Function to compute each thread's initial offset - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE static MatrixCoord initial_offset(int thread_idx) { - int warp_idx = thread_idx / kWarpSize; + int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); int lane_idx = thread_idx % kWarpSize; // Compute warp location @@ -489,7 +489,7 @@ struct OutputTileOptimalThreadMap { MatrixCoord coord( cluster_offset + group_offset + row_offset + lane_row_offset, - (column_offset + lane_col_offset) * kElementsPerAccess + column_offset + lane_col_offset * kElementsPerAccess ); return coord; diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 1cc8d1ce..d70a1989 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -43,6 +43,7 @@ #include "cutlass/array.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" #include "cutlass/matrix_shape.h" #include "cutlass/tensor_ref.h" #include "cutlass/transform/pitch_linear_thread_map.h" @@ -70,6 +71,7 @@ template < typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) typename Element_, ///< Element data type bool ScatterD = false, ///< Scatter D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not bool UseCUDAStore = false > class PredicatedTileIterator { @@ -173,9 +175,12 @@ private: /// Parameters structure containing reference and precomputed state. PredicatedTileIteratorParams params_; - /// Byte-level pointer + /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). uint8_t *byte_pointer_; + /// Byte-level pointer for store(). Due to PermuteD Op, store_byte_pointer_ may be with different address computation compared to byte_pointer_. 
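// In short (a sketch of the intent, not additional behavior): load() walks
// byte_pointer_ with the precomputed row/group/cluster increments, while store() walks
// store_byte_pointer_, which starts at the same address unless PermuteD is in use.
// With a non-trivial PermuteDLayout, store_with_byte_offset() recomputes the
// destination of each access as
//
//   store_byte_pointer_ + byte_offset
//       + permute_layout_(TensorCoord(row, col)) * sizeof(AccessType) / kElementsPerAccess
//
// so D is written in the permuted layout while C is still read through byte_pointer_
// in the unpermuted layout.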
+ uint8_t *store_byte_pointer_; + /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -196,6 +201,11 @@ private: /// Scatter indices int const *indices_; + + /// Whether to perform Permute Op + bool PermuteD; + /// PermuteDLayout + mutable PermuteDLayout permute_layout_; // // Static asserts about internal strides @@ -255,7 +265,7 @@ public: mask_.clear(); } - // Initialize pointer + // Initialize byte_pointer_ byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; @@ -265,6 +275,19 @@ public: LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; } + // store_byte_pointer_ is set to be the same with byte_pointer_ unless PermuteD is used. + store_byte_pointer_ = byte_pointer_; + + // Initialize PermuteD. If PermuteD is true, store_byte_pointer_ is initialized accordingly. + if (platform::is_same::value) { + PermuteD = false; + }else{ + PermuteD = true; + store_byte_pointer_ = reinterpret_cast(pointer); + permute_layout_ = PermuteDLayout(extent, + params_.stride * kElementsPerAccess / sizeof(AccessType)); + } + // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } @@ -272,6 +295,7 @@ public: /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset) { + store_byte_pointer_ += pointer_offset * sizeof_bits::value / 8; byte_pointer_ += pointer_offset * sizeof_bits::value / 8; } @@ -353,7 +377,7 @@ public: /// Stores a fragment to memory CUTLASS_DEVICE void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { - uint8_t *byte_pointer = byte_pointer_; + uint8_t *byte_pointer = store_byte_pointer_; AccessType const *frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL @@ -388,21 +412,38 @@ public: bool guard = row_guard && mask_.predicates[column]; + int col_offset = column * ThreadMap::Delta::kColumn; + + if (PermuteD) { + int col = col_offset + thread_start_column_; + int row = row_offset + thread_start_row_; + + TensorCoord init_coord(row, col); + + // Locate memory_pointer + memory_pointer = reinterpret_cast(byte_pointer + byte_offset + + permute_layout_(init_coord) * sizeof(AccessType) / kElementsPerAccess); + } + if (UseCUDAStore) { if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + memory_pointer[0] = frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; } } else { cutlass::arch::global_store( frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + (void *)&memory_pointer[0], guard); } + + if (!PermuteD) { + memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); + } } if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { + if (!ScatterD && !PermuteD) { byte_pointer += params_.increment_row; } } @@ -605,6 +646,10 @@ public: ++state_[0]; + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_row; + } + if (!ScatterD) { byte_pointer_ += params_.advance_row; } @@ -616,6 +661,7 @@ public: state_[0] = 0; ++state_[1]; byte_pointer_ += params_.advance_group; + store_byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -625,6 +671,7 @@ public: state_[1] = 0; ++state_[2]; byte_pointer_ += params_.advance_cluster; + store_byte_pointer_ += 
params_.advance_group; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; @@ -632,6 +679,7 @@ public: if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; byte_pointer_ += params_.advance_tile; + store_byte_pointer_ += params_.advance_group; } } } diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index 792c1697..e3eaf55e 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -35,6 +35,8 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index dd6c0406..99f64c67 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -908,3 +908,4 @@ T absolute_value(T x) { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/gemm/device/base_grouped.h b/include/cutlass/gemm/device/base_grouped.h new file mode 100644 index 00000000..e8a93415 --- /dev/null +++ b/include/cutlass/gemm/device/base_grouped.h @@ -0,0 +1,478 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Base device-level grouped kernel. 
+*/ + +#pragma once + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class BaseGrouped { +public: + + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + +protected: + + /// Kernel parameters object + typename BaseKernel::Params params_; + +private: + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaMemcpy() returned error " + << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + 
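// The members below implement the host-side scheduling path, roughly (a sketch of the
// flow; the actual calls are made from initialize() and update() further down):
//
//   size_t bytes = get_workspace_size(args);   // non-zero only when the kernel's
//                                              // ProblemVisitor requires precomputation
//   // the caller allocates `bytes` of device memory and passes it as `workspace`;
//   // precompute() then runs ProblemVisitor::host_precompute() into a host buffer and
//   // transfers it to `workspace` via copy_to_workspace()
//
// so the per-threadblock tile schedule lives in the caller-provided workspace rather
// than in any global state owned by this class.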
/// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const &args, int32_t tile_count, void* workspace) { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute(args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, const std::vector& indices) { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (int i = 0; i < indices.size(); ++i) { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + +public: + + /// Constructs the GEMM. + BaseGrouped() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const &problem) { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const &args) { + if (args.host_problem_sizes == nullptr) { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + return BaseKernel::ProblemVisitor::get_workspace_size(args.host_problem_sizes, + args.problem_count, + args.threadblock_count); + } else { + return 0; + } + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + + CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + Kernel, + BaseKernel::kThreadCount, + smem_size); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
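// A minimal host-side sketch of invoking this helper (all arrays are caller-owned
// with problem_count entries each; the names are placeholders):
//
//   DeviceGemm::sort_problems(problem_count,
//                             problem_sizes_host.data(),
//                             lda_host.data(), ldb_host.data(),
//                             ldc_host.data(), ldd_host.data(),
//                             offset_A.data(), offset_B.data(),
//                             offset_C.data(), offset_D.data());
//
// where DeviceGemm is any instantiation of this class template. All arrays are
// permuted identically, so entry i of every array still describes the same problem
// after sorting; only the order of the problems changes.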
+ static void sort_problems(int problem_count, + cutlass::gemm::GemmCoord* problem_sizes_ptr, + int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, + int64_t* ldc_host_ptr, + int64_t* ldd_host_ptr, + int64_t* offset_A_ptr, + int64_t* offset_B_ptr, + int64_t* offset_C_ptr, + int64_t* offset_D_ptr) + { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { + return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); + }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient(const cutlass::gemm::GemmCoord* problem_sizes_ptr=nullptr, + int problem_count=0, + int available_sm_count=-1) { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(result)); + return 0; + } + + result = cudaGetDeviceProperties(&properties, device_idx); + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDeviceProperties() returned error " + << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > properties.multiProcessorCount); + if (override_sm_count) { + available_sm_count = properties.multiProcessorCount; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return min(total_tiles, occupancy_based_block_count); + } + + + /// Initializes GEMM state from arguments. 
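// Taken together, a typical host-side sequence over these entry points is sketched
// below; GemmGrouped (see gemm_grouped.h later in this diff) is one concrete derived
// type, and `GemmKernel`, the problem-size arrays, and the Arguments fields are
// placeholders for whatever the caller has built (for example via
// cutlass::gemm::kernel::DefaultGemmGrouped):
//
//   using DeviceGemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
//
//   int threadblock_count = DeviceGemm::sufficient(problem_sizes_host.data(),
//                                                  problem_count);
//   typename DeviceGemm::Arguments args(/* device problem sizes, pointers, leading
//                                          dimensions, epilogue params,
//                                          threadblock_count, ... */);
//
//   DeviceGemm gemm;
//   void *workspace = nullptr;
//   cudaMalloc(&workspace, DeviceGemm::get_workspace_size(args));
//
//   cutlass::Status status = gemm.initialize(args, workspace);
//   if (status == cutlass::Status::kSuccess) {
//     status = gemm.run();   // equivalently: gemm()
//   }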
+ Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_ = typename BaseKernel::Params(args, workspace, tile_count); + } else { + params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) { + return status; + } + + params_.update(args, workspace, tile_count); + } else { + params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + // + // Configure grid and block dimensions + // + + if (!params_.problem_visitor.problem_count) { + return Status::kSuccess; + } + + dim3 grid(params_.threadblock_count, 1, 1); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + // + // Launch kernel + // + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Initializes and runs the kernel. 
+ Status operator()( + Arguments const &args, + void *workspace, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index 045fba74..179a4b2d 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -45,6 +45,8 @@ #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -225,7 +227,9 @@ template < /// Gather operand B by using an index array bool GatherB = false, /// Scatter result D by using an index array - bool ScatterD = false> + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> class Gemm { public: @@ -280,7 +284,8 @@ class Gemm { SharedMemoryClearOption::kNone, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout >::GemmKernel; /// Argument structure @@ -559,14 +564,16 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > class Gemm { + Operator_, GatherA, GatherB, ScatterD, PermuteDLayout> { public: using ElementA = ElementA_; @@ -617,7 +624,8 @@ class Gemm; using UnderlyingArguments = typename UnderlyingOperator::Arguments; diff --git a/include/cutlass/gemm/device/gemm_grouped.h b/include/cutlass/gemm/device/gemm_grouped.h index 628a56b0..ca21b562 100644 --- a/include/cutlass/gemm/device/gemm_grouped.h +++ b/include/cutlass/gemm/device/gemm_grouped.h @@ -28,29 +28,14 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! +/*! \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. + \brief Device-level grouped GEMM. 
*/ #pragma once -#include - -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" -#include "cutlass/arch/arch.h" -#include "cutlass/device_kernel.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/gemm/kernel/gemm_universal.h" - -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" - -#include "cutlass/trace.h" +#include "cutlass/gemm/device/base_grouped.h" //////////////////////////////////////////////////////////////////////////////// @@ -62,220 +47,9 @@ namespace device { /// GEMM Grouped template -class GemmGrouped { +class GemmGrouped : public BaseGrouped { public: - using GemmKernel = GemmKernel_; - - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; - static int const kAlignmentA = GemmKernel::kAlignmentA; - - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; - static int const kAlignmentB = GemmKernel::kAlignmentB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = GemmKernel::kAlignmentC; - - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - - using Operator = typename GemmKernel::Operator; - using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename GemmKernel::Mma::Shape; - using WarpShape = typename GemmKernel::WarpShape; - using InstructionShape = typename GemmKernel::InstructionShape; - static int const kStages = GemmKernel::Mma::kStages; - - /// Argument structure - using Arguments = typename GemmKernel::Arguments; - -protected: - - /// Kernel parameters object - typename GemmKernel::Params params_; - -public: - - /// Constructs the GEMM. - GemmGrouped() { } - - /// Determines whether the GEMM can execute the given problem. 
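Editor's note: the hunk above replaces the hand-written device-level GemmGrouped with a thin subclass of BaseGrouped, which now supplies the Arguments/Params plumbing, workspace sizing, initialize(), update(), and run() being deleted in this hunk. A minimal sketch of the same pattern, with placeholder names and not taken from this patch:

// Sketch: defining a device-level grouped operator by inheriting BaseGrouped.
// `GroupedKernel_` is a placeholder for a kernel-level type such as one produced
// by kernel::DefaultGemmGrouped<...>::GemmKernel.
#include "cutlass/gemm/device/base_grouped.h"

template <typename GroupedKernel_>
class MyGroupedOp : public cutlass::gemm::device::BaseGrouped<GroupedKernel_> {
public:
  using GroupedKernel = GroupedKernel_;
  // Only operator-specific constants or aliases need to be re-exported here,
  // as the grouped Rank2K operator later in this patch does with its fill mode
  // and BLAS mode.
};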
- static Status can_implement(Arguments const &args) { - - return GemmKernel::can_implement(args); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const &args) { - - // This kerenl does not utilize a workspace - return size_t(); - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const &args) { - - return dim3(args.threadblock_count, 1, 1); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) { - result = cudaFuncSetAttribute(Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaFuncSetAttribute() returned error " - << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - Kernel, - GemmKernel::kThreadCount, - smem_size); - - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - - CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) { - return Status::kErrorWorkspaceNull; - } - - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, workspace); - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const &args, void *workspace = nullptr) { - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) { - return Status::kErrorWorkspaceNull; - } - - params_.update(args, workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { - - // - // Configure grid and block dimensions - // - - if (!params_.problem_visitor.problem_count) { - return Status::kSuccess; - } - - dim3 grid(params_.threadblock_count, 1, 1); - dim3 block(GemmKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - // - // Launch kernel - // - - // Launch - cutlass::Kernel<<>>(params_); - - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. 
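Editor's note: the maximum_active_blocks() being removed above (its logic now lives in the base class) is the standard CUDA recipe for kernels whose shared storage exceeds the 48 KB static limit: opt in with cudaFuncSetAttribute, then query occupancy. A standalone sketch of that recipe with a generic kernel, not CUTLASS-specific:

#include <cuda_runtime.h>

__global__ void my_kernel(char *out) {
  extern __shared__ char smem[];       // dynamic shared memory
  if (threadIdx.x == 0) {
    smem[0] = 1;
    out[blockIdx.x] = smem[0];
  }
}

int occupancy_with_large_smem(int smem_size, int block_threads) {
  // Requests above 48 KB of dynamic shared memory require an explicit opt-in.
  if (smem_size > (48 << 10)) {
    if (cudaFuncSetAttribute(my_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size) != cudaSuccess) {
      return -1;
    }
  }

  int max_active_blocks = -1;
  if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &max_active_blocks, my_kernel, block_threads, smem_size) != cudaSuccess) {
    return -1;
  }
  return max_active_blocks;
}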
- Status operator()(cudaStream_t stream = nullptr) { - return run(stream); - } - - /// Runs the kernel using initialized state. - Status operator()( - Arguments const &args, - void *workspace = nullptr, - cudaStream_t stream = nullptr) { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) { - status = run(stream); - } - - return status; - } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h new file mode 100644 index 00000000..a4b75ebc --- /dev/null +++ b/include/cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Device-level GEMM with layernorm elementwise operations fused in mainloop +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! 
+ The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for Scale/Bias vectors + typename ElementScaleBias_, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator +> +class GemmLayernormMainloopFusion : + public GemmUniversalBase< + typename kernel::DefaultGemmLayernormMainloopFusion< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementScaleBias_, + LayoutScaleBias_, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + > { + + public: + + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; 
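Editor's note: for reference, an illustrative instantiation of the primary template above. The element, layout, and architecture choices here are placeholders rather than values required by this header, and every parameter not listed falls back to DefaultGemmConfiguration as declared above; this snippet is not part of the patch.

using GemmLayernorm = cutlass::gemm::device::GemmLayernormMainloopFusion<
    cutlass::half_t, cutlass::layout::RowMajor,      // ElementA, LayoutA
    cutlass::half_t, cutlass::layout::ColumnMajor,   // ElementB, LayoutB
    cutlass::half_t, cutlass::layout::RowMajor,      // ElementScaleBias, LayoutScaleBias
    cutlass::half_t, cutlass::layout::RowMajor,      // ElementC/D, LayoutC
    float,                                           // ElementAccumulator
    cutlass::arch::OpClassTensorOp,                  // Tensor Core math
    cutlass::arch::Sm80>;                            // minimum target architecture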
+ static int const kAlignmentC = EpilogueOutputOp::kCount; + + using Base = GemmUniversalBase< + typename kernel::DefaultGemmLayernormMainloopFusion< + ElementA_, + LayoutA_, + AlignmentA, + ElementB_, + LayoutB_, + AlignmentB, + ElementScaleBias_, + LayoutScaleBias_, + ElementC_, + LayoutC_, + ElementAccumulator_, + OperatorClass_, + ArchTag_, + ThreadblockShape_, + WarpShape_, + InstructionShape_, + EpilogueOutputOp_, + ThreadblockSwizzle_, + Stages, + Operator_, + SharedMemoryClearOption::kNone + >::GemmKernel + >; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Parital specialization for column-major output exchanges problem size and operand. +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for Scale/Bias vectors + typename ElementScaleBias_, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_ +> +class GemmLayernormMainloopFusion { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementScaleBias = ElementScaleBias_; + using LayoutScaleBias = LayoutScaleBias_; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + + using UnderlyingOperator = typename GemmLayernormMainloopFusion< + ElementB, + typename layout::LayoutTranspose::type, + ElementA, + typename layout::LayoutTranspose::type, + ElementScaleBias, + LayoutScaleBias, + ElementC, + layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + 
InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + kAlignmentB, + kAlignmentA, + Operator + >::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + +private: + + UnderlyingOperator underlying_operator_; + +public: + + /// Constructs the GEMM. + GemmLayernormMainloopFusion() { } + + /// Helper to construct a transposed equivalent for the underlying GEMM operator + static Arguments to_underlying_arguments(Arguments const &args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + return underlying_operator_.update(to_underlying_arguments(args), workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
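Editor's note: the column-major specialization above never implements its own kernel. to_underlying_arguments() calls args.transposed_problem() and delegates to the row-major operator with A and B exchanged, their layouts transposed via layout::LayoutTranspose, and the alignment parameters swapped (kAlignmentB, kAlignmentA). The identity it relies on is

    D = alpha * (A x B) + beta * C,       with D stored column-major,
is byte-for-byte the same computation as
    D^T = alpha * (B^T x A^T) + beta * C^T,   with D^T stored row-major,

so a row-major-output kernel run on the exchanged, transposed operands writes exactly the memory a column-major D requires.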
+ Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal.h b/include/cutlass/gemm/device/gemm_universal.h index ddd997a8..f2a09cd2 100644 --- a/include/cutlass/gemm/device/gemm_universal.h +++ b/include/cutlass/gemm/device/gemm_universal.h @@ -47,6 +47,8 @@ #include "cutlass/gemm/device/default_gemm_configuration.h" #include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/layout/permute.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -123,7 +125,9 @@ template < /// Gather operand B by using an index array bool GatherB = false, /// Scatter result D by using an index array - bool ScatterD = false + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute > class GemmUniversal : public GemmUniversalBase< @@ -151,7 +155,8 @@ class GemmUniversal : SharedMemoryClearOption::kNone, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout >::GemmKernel > { @@ -198,7 +203,8 @@ class GemmUniversal : SharedMemoryClearOption::kNone, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout >::GemmKernel >; @@ -255,14 +261,16 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > class GemmUniversal { + Operator_, TransformA, TransformB, GatherA, GatherB, ScatterD, PermuteDLayout> { public: using ElementA = ElementA_; @@ -313,7 +321,8 @@ class GemmUniversal::Base; using GemmKernel = typename UnderlyingOperator::GemmKernel; diff --git a/include/cutlass/gemm/device/rank_2k_grouped.h b/include/cutlass/gemm/device/rank_2k_grouped.h new file mode 100644 index 00000000..3c7ab614 --- /dev/null +++ b/include/cutlass/gemm/device/rank_2k_grouped.h @@ -0,0 +1,63 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Device-level grouped Rank2K. +*/ + +#pragma once + +#include "cutlass/gemm/device/base_grouped.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Rank2K Grouped +template +class Rank2KGrouped : public BaseGrouped { +public: + using Rank2Kkernel = Rank2Kkernel_; + static const cutlass::FillMode kFillModeC = Rank2Kkernel::kFillModeC; + static const cutlass::BlasMode kBlasMode = Rank2Kkernel::kBlasMode; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm.h b/include/cutlass/gemm/kernel/default_gemm.h index 8b433d24..38efae7e 100644 --- a/include/cutlass/gemm/kernel/default_gemm.h +++ b/include/cutlass/gemm/kernel/default_gemm.h @@ -65,6 +65,8 @@ #include "cutlass/epilogue/threadblock/default_epilogue_simt.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/layout/permute.h" + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" #endif //CUTLASS_ARCH_WMMA_ENABLED @@ -125,6 +127,8 @@ template < bool GatherB = false, /// Scatter result D by using an index array bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, /// typename Enable = void > @@ -177,13 +181,15 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemm { + Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, PermuteDLayout> { static_assert(platform::is_same::value || platform::is_same>::value, @@ -202,14 +208,14 @@ struct DefaultGemm::Epilogue; + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount>::Epilogue; - using Epilogue = typename cutlass::platform::conditional::value, + using Epilogue = typename platform::conditional::value, RegularEpilogue, Affine2Epilogue>::type; @@ -258,7 +264,9 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -278,7 +286,8 @@ struct DefaultGemm< 
SharedMemoryClear, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -313,7 +322,8 @@ struct DefaultGemm< kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, - ScatterD + ScatterD, + PermuteDLayout >::Epilogue; /// Define the kernel-level GEMM operator. @@ -493,7 +503,9 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemm< ElementA, LayoutA, kAlignmentA, @@ -513,7 +525,8 @@ struct DefaultGemm< SharedMemoryClear, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -548,7 +561,8 @@ struct DefaultGemm< kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount, - ScatterD + ScatterD, + PermuteDLayout >::Epilogue; /// Define the kernel-level GEMM operator. @@ -598,7 +612,9 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemm< ElementA, @@ -624,6 +640,7 @@ struct DefaultGemm< GatherA, GatherB, ScatterD, + PermuteDLayout, typename platform::enable_if< ! platform::is_same::value >::type > { static_assert(platform::is_same::value @@ -661,7 +678,8 @@ struct DefaultGemm< typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess, - ScatterD + ScatterD, + PermuteDLayout >::Epilogue; using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< @@ -672,7 +690,7 @@ struct DefaultGemm< kEpilogueElementsPerAccess >::Epilogue; - using Epilogue = typename cutlass::platform::conditional::value, + using Epilogue = typename platform::conditional::value, RegularEpilogue, Affine2Epilogue>::type; @@ -723,7 +741,9 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemm { + ScatterD, + PermuteDLayout> { static_assert(platform::is_same::value || platform::is_same>::value, @@ -769,7 +790,8 @@ struct DefaultGemm::Epilogue; using Affine2Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueSimtAffineRankN< @@ -780,7 +802,7 @@ struct DefaultGemm::Epilogue; - using Epilogue = typename cutlass::platform::conditional::value, + using Epilogue = typename platform::conditional::value, RegularEpilogue, Affine2Epilogue>::type; diff --git a/include/cutlass/gemm/kernel/default_gemm_grouped.h b/include/cutlass/gemm/kernel/default_gemm_grouped.h index ec9cac32..2ff7f88c 100644 --- a/include/cutlass/gemm/kernel/default_gemm_grouped.h +++ b/include/cutlass/gemm/kernel/default_gemm_grouped.h @@ -54,6 +54,8 @@ #include "cutlass/gemm/kernel/default_gemm_complex.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/layout/permute.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -101,12 +103,16 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, /// Operation performed by GEMM typename Operator = typename device::DefaultGemmConfiguration< 
OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, ElementAccumulator>::Operator, /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, /// typename Enable = void > @@ -152,10 +158,14 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, /// Operation performed by GEMM typename Operator, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout > struct DefaultGemmGrouped< ElementA, @@ -177,9 +187,11 @@ struct DefaultGemmGrouped< EpilogueOutputOp, ThreadblockSwizzle, Stages, + GroupScheduleMode_, Operator, SharedMemoryClear, - typename std::enable_if< ! cutlass::is_complex::value>::type + PermuteDLayout, + typename platform::enable_if< ! cutlass::is_complex::value>::type > { // If true, we must construct a 'transposed-and-exchanged' Mma operator. @@ -219,7 +231,11 @@ struct DefaultGemmGrouped< Stages, true, Operator, - SharedMemoryClear + SharedMemoryClear, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + PermuteDLayout >::GemmKernel; /// Define the kernel in terms of the default kernel @@ -227,6 +243,7 @@ struct DefaultGemmGrouped< typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle, + GroupScheduleMode_, kInternalTranspose >; }; @@ -276,6 +293,8 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, /// Operation performed by GEMM typename Operator, /// Use zfill or predicate for out-of-bound cp.async @@ -301,9 +320,11 @@ struct DefaultGemmGrouped< EpilogueOutputOp, ThreadblockSwizzle, Stages, + GroupScheduleMode_, Operator, SharedMemoryClear, - typename std::enable_if::value>::type + layout::NoPermute, /*PermuteDLayout*/ + typename platform::enable_if::value>::type > { // If true, we must construct a 'transposed-and-exchanged' Mma operator. @@ -349,6 +370,7 @@ struct DefaultGemmGrouped< typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle, + GroupScheduleMode_, kInternalTranspose >; }; diff --git a/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h b/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h new file mode 100644 index 00000000..05756980 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level softmax-grouped-GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for Scale/Bias vectors + typename ElementScaleBias_, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename 
EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, + ElementAccumulator>::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone + > +struct DefaultGemmGroupedSoftmaxMainloopFusion { + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments< + ElementA_, + LayoutA_, + ComplexTransform::kNone, + kAlignmentA, + ElementB_, + LayoutB_, + ComplexTransform::kNone, + kAlignmentB, + LayoutC_, + kInternalTranspose + >; + +private: + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMmaSoftmaxMainloopFusion< + typename MapArguments::ElementA, typename MapArguments::LayoutA, MapArguments::kAlignmentA, + typename MapArguments::ElementB, typename MapArguments::LayoutB, MapArguments::kAlignmentB, + ElementScaleBias_, LayoutScaleBias_, ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, Stages, kInternalTranspose, + Operator, false, SharedMemoryClear>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + +public: + using GemmKernel = kernel::GemmGroupedSoftmaxMainloopFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h new file mode 100644 index 00000000..db8a2b9a --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_layernorm_mainloop_fusion.h @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for Scale/Bias vectors + typename ElementScaleBias, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename 
ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultGemmLayernormMainloopFusion { + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultMmaLayernormMainloopFusion< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementScaleBias, LayoutScaleBias, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, SharedMemoryClear>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::GemmLayernormMainloopFusion; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/default_gemm_universal.h b/include/cutlass/gemm/kernel/default_gemm_universal.h index 45b2c769..be9634e9 100644 --- a/include/cutlass/gemm/kernel/default_gemm_universal.h +++ b/include/cutlass/gemm/kernel/default_gemm_universal.h @@ -52,6 +52,8 @@ #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/layout/permute.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -109,6 +111,8 @@ template < bool GatherB = false, /// Scatter result D by using an index array bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, /// typename Enable = void > @@ -163,7 +167,9 @@ template < /// Gather operand B by using an index array bool GatherB, /// Scatter result D by using an index array - bool ScatterD + bool ScatterD, + /// Permute result D + typename PermuteDLayout > struct DefaultGemmUniversal< ElementA, @@ -190,6 +196,7 @@ struct DefaultGemmUniversal< GatherA, GatherB, ScatterD, + PermuteDLayout, typename platform::enable_if< ! cutlass::is_complex::value>::type > { @@ -216,7 +223,8 @@ struct DefaultGemmUniversal< SharedMemoryClear, GatherA, GatherB, - ScatterD + ScatterD, + PermuteDLayout >::GemmKernel; /// Define the kernel in terms of the default kernel @@ -302,6 +310,7 @@ struct DefaultGemmUniversal< false, false, false, + layout::NoPermute, typename platform::enable_if::value>::type > { diff --git a/include/cutlass/gemm/kernel/default_rank_2k_grouped.h b/include/cutlass/gemm/kernel/default_rank_2k_grouped.h new file mode 100644 index 00000000..5b6db616 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_rank_2k_grouped.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level grouped Rank2K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" +#include "cutlass/gemm/kernel/default_rank_2k.h" +#include "cutlass/gemm/kernel/default_rank_2k_complex.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + 
typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Blas3 computation mode + BlasMode BlasMode_ = BlasMode::kSymmetric, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// + typename Enable = void + > +struct DefaultRank2KGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued grouped Rank2K +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Blas3 computation mode + BlasMode BlasMode_, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ + > +struct DefaultRank2KGrouped::value>::type +> { + // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. 
+ static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::Rank2KMapArguments< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + LayoutC, + FillModeC, + kInternalTranspose + >; + + // Define the default grouped Rank2K kernel + using DefaultRank2Kkernel = typename kernel::DefaultRank2K< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementC, + typename MapArguments::LayoutC, + MapArguments::kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + false, // SplitKSerial + Operator, + BlasMode_ + >::Rank2Kkernel; + + /// Define the kernel in terms of the default kernel + using Rank2Kkernel = kernel::Rank2KGrouped< + typename DefaultRank2Kkernel::Mma1, + typename DefaultRank2Kkernel::Mma2, + typename DefaultRank2Kkernel::Epilogue, + ThreadblockSwizzle, + TransformA, + TransformB, + DefaultRank2Kkernel::kFillModeC, + DefaultRank2Kkernel::kBlasMode, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Complex-valued grouped Rank2K +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Fill Mode for C (kLower or kUpper) + FillMode FillModeC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Blas3 computation mode + BlasMode BlasMode_, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ + > +struct DefaultRank2KGrouped::value>::type +> { + // If true, we must construct a 'transposed-and-exchanged' Rank2K operator. 
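Editor's note: both DefaultRank2KGrouped partial specializations select themselves through the trailing parameter declared as `typename Enable = void` on the primary template, one constrained on a real-valued accumulator and one on a complex-valued accumulator. The following self-contained sketch (standard C++, not CUTLASS code, with a stand-in trait for cutlass::is_complex) shows the same dispatch mechanism in isolation:

#include <complex>
#include <type_traits>

// Minimal stand-in for cutlass::is_complex.
template <typename T> struct is_complex_number : std::false_type {};
template <typename T> struct is_complex_number<std::complex<T>> : std::true_type {};

// Primary template: declared only; the defaulted trailing parameter is the dispatch hook.
template <typename Accumulator, typename Enable = void>
struct DispatchByAccumulator;

// Selected when the accumulator is real-valued.
template <typename Accumulator>
struct DispatchByAccumulator<
    Accumulator,
    typename std::enable_if<!is_complex_number<Accumulator>::value>::type> {
  static constexpr bool is_complex_path = false;
};

// Selected when the accumulator is complex-valued.
template <typename Accumulator>
struct DispatchByAccumulator<
    Accumulator,
    typename std::enable_if<is_complex_number<Accumulator>::value>::type> {
  static constexpr bool is_complex_path = true;
};

static_assert(!DispatchByAccumulator<float>::is_complex_path,
              "real accumulators take the real-valued path");
static_assert(DispatchByAccumulator<std::complex<float>>::is_complex_path,
              "complex accumulators take the complex-valued path");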
+ static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::Rank2KMapArguments< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + LayoutC, + FillModeC, + kInternalTranspose + >; + + // Define the default grouped Rank2K kernel + using DefaultRank2Kkernel = typename kernel::DefaultRank2KComplex< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + ElementC, + typename MapArguments::LayoutC, + MapArguments::kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MapArguments::kTransformA, + MapArguments::kTransformB, + Operator, + false, // SplitKSerial + BlasMode_ + >::Rank2Kkernel; + + /// Define the kernel in terms of the default kernel + /// Pass through the user-provided TransformA and TransformB so as to + /// correctly set public-facing TransformA and TransformB in kernel::Rank2KGrouped. + /// This is needed because kernel::DefaultRank2KComplex may change TransformA and + /// TransformB that become template arguments to Mma1 and Mma2. + using Rank2Kkernel = kernel::Rank2KGrouped< + typename DefaultRank2Kkernel::Mma1, + typename DefaultRank2Kkernel::Mma2, + typename DefaultRank2Kkernel::Epilogue, + ThreadblockSwizzle, + TransformA, + TransformB, + DefaultRank2Kkernel::kFillModeC, + DefaultRank2Kkernel::kBlasMode, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_grouped.h b/include/cutlass/gemm/kernel/gemm_grouped.h index ceca1f7d..c02d3ff9 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/include/cutlass/gemm/kernel/gemm_grouped.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief Problem visitor for grouped GEMMs */ #pragma once @@ -45,6 +45,7 @@ #include "cutlass/layout/matrix.h" #include "cutlass/trace.h" #include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -54,168 +55,11 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct GemmGroupedProblemVisitor { - - static bool const kTransposed = Transposed; - - struct Params { - cutlass::gemm::GemmCoord const *problem_sizes; - int32_t problem_count; - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - Params(): problem_sizes(nullptr), problem_count(0) { } - - /// Ctor - CUTLASS_HOST_DEVICE - Params( - cutlass::gemm::GemmCoord const *problem_sizes, - int32_t problem_count - ): - problem_sizes(problem_sizes), - problem_count(problem_count) - {} - - }; - - struct SharedStorage { - // - // Nothing for now. As an optimization step, we could consider parallel - // argmin or prefix sums across the block. 
- // - }; - - // - // Data members - // - - Params const ¶ms; - SharedStorage &shared_storage; - cutlass::MatrixCoord threadblock_shape; - - int64_t tile_idx; - int64_t tile_count_sum; - int64_t problem_tile_start; - int32_t problem_idx; - - // - // Methods - // - CUTLASS_DEVICE - GemmGroupedProblemVisitor( - Params const ¶ms_, - SharedStorage &shared_storage_, - cutlass::MatrixCoord threadblock_shape_, - int32_t block_idx - ): - shared_storage(shared_storage_), - params(params_), - threadblock_shape(threadblock_shape_), - tile_idx(block_idx), - tile_count_sum(0), - problem_idx(0) - { - - cutlass::gemm::GemmCoord problem = problem_size(); - cutlass::gemm::GemmCoord grid = grid_shape(problem); - - problem_tile_start = 0; - tile_count_sum = grid.m() * grid.n(); - } - - /// Get the grid shape - CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape( - cutlass::gemm::GemmCoord problem, - cutlass::MatrixCoord const & block_shape) { - - return cutlass::gemm::GemmCoord( - ((problem.m() - 1 + block_shape.row()) / block_shape.row()), - ((problem.n() - 1 + block_shape.column()) / block_shape.column()), - 1); - } - - /// Get the grid shape - CUTLASS_DEVICE - cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const &problem) const { - return grid_shape(problem, threadblock_shape); - } - - /// Returns true if there is a tile to compute - CUTLASS_DEVICE - bool next_tile() { - - if (tile_idx < tile_count_sum) { - return true; - } - - do { - ++problem_idx; - - if (problem_idx >= params.problem_count) { - return false; - } - - cutlass::gemm::GemmCoord problem = problem_size(); - cutlass::gemm::GemmCoord grid = grid_shape(problem); - - int64_t tile_count = grid.m() * grid.n(); - - problem_tile_start = tile_count_sum; - tile_count_sum += tile_count; - - } while (tile_count_sum <= tile_idx); - - return true; - } - - /// Gets the global tile index - CUTLASS_HOST_DEVICE - int64_t tile_index() const { - return tile_idx; - } - - /// Gets the index of the problem - CUTLASS_HOST_DEVICE - int32_t problem_index() const { - return problem_idx; - } - - /// Returns the problem size for the current problem - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size() const { - GemmCoord problem = params.problem_sizes[problem_idx]; - - if (kTransposed) { - swap(problem.m(), problem.n()); - } - - return problem; - } - - CUTLASS_HOST_DEVICE - int64_t threadblock_index() const { - return tile_idx - problem_tile_start; - } - - CUTLASS_DEVICE - void advance(int32_t grid_size) { - tile_idx += grid_size; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + GroupScheduleMode GroupScheduleMode_, ///! 
Type of scheduling to perform bool Transposed = false > struct GemmGrouped { @@ -225,6 +69,7 @@ public: using Epilogue = Epilogue_; using EpilogueOutputOp = typename Epilogue::OutputOp; using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; static bool const kTransposed = Transposed; // Optional transpose @@ -270,6 +115,13 @@ public: using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; + using ProblemVisitor = GemmGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount, + kTransposed>; + // // Structures // @@ -290,13 +142,16 @@ public: ElementA ** ptr_A; ElementB ** ptr_B; ElementC ** ptr_C; - ElementC ** ptr_D; + ElementC ** ptr_D; typename LayoutA::Stride::LongIndex *lda; typename LayoutB::Stride::LongIndex *ldb; typename LayoutC::Stride::LongIndex *ldc; typename LayoutC::Stride::LongIndex *ldd; + // Only used by device-level operator + GemmCoord *host_problem_sizes; + // // Methods // @@ -304,7 +159,7 @@ public: /// Default ctor CUTLASS_HOST_DEVICE Arguments(): - problem_count(0), + problem_count(0), threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), @@ -313,7 +168,8 @@ public: lda(nullptr), ldb(nullptr), ldc(nullptr), - ldd(nullptr) + ldd(nullptr), + host_problem_sizes(nullptr) { } @@ -328,11 +184,12 @@ public: ElementA ** ptr_A, ElementB ** ptr_B, ElementC ** ptr_C, - ElementC ** ptr_D, + ElementC ** ptr_D, typename LayoutA::Stride::LongIndex *lda, typename LayoutB::Stride::LongIndex *ldb, typename LayoutC::Stride::LongIndex *ldc, - typename LayoutC::Stride::LongIndex *ldd + typename LayoutC::Stride::LongIndex *ldd, + GemmCoord *host_problem_sizes=nullptr ): problem_sizes(problem_sizes), problem_count(problem_count), @@ -345,7 +202,8 @@ public: lda(lda), ldb(ldb), ldc(ldc), - ldd(ldd) + ldd(ldd), + host_problem_sizes(host_problem_sizes) { } @@ -358,7 +216,7 @@ public: /// Parameters structure struct Params { - typename GemmGroupedProblemVisitor::Params problem_visitor; + typename ProblemVisitor::Params problem_visitor; int threadblock_count; typename EpilogueOutputOp::Params output_op; @@ -373,7 +231,6 @@ public: typename LayoutC::Stride::LongIndex *ldc; typename LayoutC::Stride::LongIndex *ldd; - // // Methods // @@ -391,8 +248,10 @@ public: { } CUTLASS_HOST_DEVICE - Params(Arguments const &args, void *workspace = nullptr): - problem_visitor(args.problem_sizes, args.problem_count), + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), threadblock_count(args.threadblock_count), output_op(args.output_op), ptr_A(args.ptr_A), @@ -410,9 +269,11 @@ public: CUTLASS_HOST_DEVICE void update( Arguments const &args, - void *workspace = nullptr) { + void *workspace = nullptr, + int tile_count = 0) { - problem_visitor = typename GemmGroupedProblemVisitor::Params(args.problem_sizes, args.problem_count); + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, + workspace, tile_count); threadblock_count = args.threadblock_count; output_op = args.output_op; ptr_A = args.ptr_A; @@ -427,10 +288,14 @@ public: }; /// Shared memory storage structure - union SharedStorage { - typename GemmGroupedProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + 
typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; }; public: @@ -476,24 +341,23 @@ public: // // Problem visitor. // - GemmGroupedProblemVisitor problem_visitor( - params.problem_visitor, - shared_storage.problem_visitor, - {Mma::Shape::kM, Mma::Shape::kN}, + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, blockIdx.x); // Outer 'persistent' loop to iterate over tiles while (problem_visitor.next_tile()) { - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_index()); + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); cutlass::gemm::GemmCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, - int(cta_idx % grid_shape.n()) * Mma::Shape::kN, + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); // Load element pointers. Exchange pointers and strides if working on the transpose @@ -547,7 +411,7 @@ public: // // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; @@ -597,7 +461,7 @@ public: ); Epilogue epilogue( - shared_storage.epilogue, + shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); diff --git a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h new file mode 100644 index 00000000..51fec120 --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct GemmGroupedProblemSizeHelper { + + static bool const kTransposed = Transposed; + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + if (kTransposed) { + swap(problem.m(), problem.n()); + } + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmGroupedProblemVisitor : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base (params_, shared_storage_, block_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h new file mode 100644 index 00000000..4b2d90bb --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_grouped_softmax_mainloop_fusion.h @@ -0,0 +1,517 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Problem visitor for grouped GEMMs with a softmax fused beforehand +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform + bool Transposed = false +> +struct GemmGroupedSoftmaxMainloopFusion { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments< + typename Mma::IteratorA::Element, + typename Mma::IteratorA::Layout, + Mma::kTransformA, + Mma::IteratorA::AccessType::kElements, + typename Mma::IteratorB::Element, + typename Mma::IteratorB::Layout, + Mma::kTransformB, + Mma::IteratorB::AccessType::kElements, + typename Mma::LayoutC, + kTransposed + >; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. 
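+  // (Descriptive note: MapArguments carries out the operand exchange implied
+  //  by kTransposed; when it is set, the A/B types, layouts, transforms, and
+  //  alignments defined below refer to the exchanged operands, and operator()
+  //  swaps the per-problem pointers and leading dimensions to match.)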
+ using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementScaleBias = typename Mma::IteratorNormSum::Element; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount, + kTransposed>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA ** ptr_A; + ElementB ** ptr_B; + ElementC ** ptr_C; + ElementC ** ptr_D; + void ** ptr_norm; + void ** ptr_sum; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): + problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_norm(nullptr), + ptr_sum(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr) + { + + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + ElementA ** ptr_A, + ElementB ** ptr_B, + ElementC ** ptr_C, + ElementC ** ptr_D, + void ** ptr_norm, + void ** ptr_sum, + typename LayoutA::Stride::LongIndex *lda, + typename LayoutB::Stride::LongIndex *ldb, + typename LayoutC::Stride::LongIndex *ldc, + typename LayoutC::Stride::LongIndex *ldd, + GemmCoord *host_problem_sizes=nullptr + ): + problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + ptr_norm(ptr_norm), + ptr_sum(ptr_sum), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA ** ptr_A; + ElementB ** ptr_B; + 
ElementC ** ptr_C; + ElementC ** ptr_D; + + void ** ptr_norm; + void ** ptr_sum; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_norm(nullptr), + ptr_sum(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + ptr_norm(args.ptr_norm), + ptr_sum(args.ptr_sum), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, + workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_norm = args.ptr_norm; + ptr_sum = args.ptr_sum; + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmGroupedSoftmaxMainloopFusion() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. 
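+    // Each threadblock runs a persistent loop: it starts at global tile index
+    // blockIdx.x and advances by gridDim.x after every tile (see
+    // problem_visitor.advance(gridDim.x) at the bottom of the loop), with the
+    // visitor mapping each global tile index to a (problem, tile) pair.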
+ // + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_offset.n() + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k(), problem_size.n()}, + thread_idx, + tb_offset_B); + + // Construct iterator to the softmax norm/sum vector + typename Mma::IteratorNormSum iterator_norm_sum( + problem_size.m(), + static_cast(params.ptr_norm[problem_idx]), + static_cast(params.ptr_sum[problem_idx]), + thread_idx, + MatrixCoord(0, threadblock_offset.m()) + ); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_norm_sum, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC *ptr_C = params.ptr_C[problem_idx]; + ElementC *ptr_D = params.ptr_D[problem_idx]; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + Epilogue epilogue( + shared_storage.kernel.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h new file mode 100644 index 00000000..f83f5f6b --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a multistage GEMM kernel with layernorm operations fused in mainloop. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmLayernormMainloopFusion { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ElementScaleBias = typename Mma::IteratorVarMean::Element; + using LayoutScaleBias = typename Mma::IteratorVarMean::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_var; + void const * ptr_mean; + void const * ptr_gamma; + void const * ptr_beta; + void const * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_var; + int64_t batch_stride_mean; + int64_t batch_stride_gamma; + int64_t batch_stride_beta; + int64_t batch_stride_C; + int64_t batch_stride_D; + + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutScaleBias::Stride stride_var; + typename LayoutScaleBias::Stride stride_mean; + typename LayoutScaleBias::Stride stride_gamma; + typename LayoutScaleBias::Stride stride_beta; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutScaleBias::Stride::LongIndex ld_var; + typename LayoutScaleBias::Stride::LongIndex ld_mean; + typename LayoutScaleBias::Stride::LongIndex ld_gamma; + typename LayoutScaleBias::Stride::LongIndex ld_beta; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + + int const * ptr_gather_A_indices; + int const * ptr_gather_B_indices; + int const * ptr_scatter_D_indices; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), + ptr_var(nullptr), ptr_mean(nullptr), + 
ptr_gamma(nullptr), ptr_beta(nullptr), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr) {} + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_var, + void const * ptr_mean, + void const * ptr_gamma, + void const * ptr_beta, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_var, + int64_t batch_stride_mean, + int64_t batch_stride_gamma, + int64_t batch_stride_beta, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutScaleBias::Stride stride_var, + typename LayoutScaleBias::Stride stride_mean, + typename LayoutScaleBias::Stride stride_gamma, + typename LayoutScaleBias::Stride stride_beta, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_var(ptr_var), ptr_mean(ptr_mean), + ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), + stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), + stride_var(stride_var), stride_mean(stride_mean), + stride_gamma(stride_gamma), stride_beta(stride_beta), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + lda = 0; + ldb = 0; + ldc = 0; + ldd = 0; + ld_var = 0; + ld_mean = 0; + ld_gamma = 0; + ld_beta = 0; + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_var, + void const * ptr_mean, + void const * ptr_gamma, + void const * ptr_beta, + void const * ptr_C, + void * ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_var, + int64_t batch_stride_mean, + int64_t batch_stride_gamma, + int64_t batch_stride_beta, + int64_t batch_stride_C, + int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutScaleBias::Stride::LongIndex ld_var, + typename LayoutScaleBias::Stride::LongIndex ld_mean, + typename LayoutScaleBias::Stride::LongIndex ld_gamma, + typename LayoutScaleBias::Stride::LongIndex ld_beta, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int const *ptr_gather_A_indices = nullptr, + int const *ptr_gather_B_indices = nullptr, + int const *ptr_scatter_D_indices = nullptr + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_var(ptr_var), ptr_mean(ptr_mean), + ptr_gamma(ptr_gamma), ptr_beta(ptr_beta), + batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), 
batch_stride_D(batch_stride_D), + batch_stride_var(batch_stride_var), batch_stride_mean(batch_stride_mean), + batch_stride_gamma(batch_stride_gamma), batch_stride_beta(batch_stride_beta), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), + ld_var(ld_var), ld_mean(ld_mean), + ld_gamma(ld_gamma), ld_beta(ld_beta), + ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + stride_var = make_Coord(ld_var); + stride_mean = make_Coord(ld_mean); + stride_gamma = make_Coord(ld_gamma); + stride_beta = make_Coord(ld_beta); + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_var; + void * ptr_mean; + void * ptr_gamma; + void * ptr_beta; + void * ptr_C; + void * ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_var; + int64_t batch_stride_mean; + int64_t batch_stride_gamma; + int64_t batch_stride_beta; + int64_t batch_stride_C; + int64_t batch_stride_D; + + int * ptr_gather_A_indices; + int * ptr_gather_B_indices; + int * ptr_scatter_D_indices; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_var(nullptr), + ptr_mean(nullptr), + ptr_gamma(nullptr), + ptr_beta(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_var(0), + batch_stride_mean(0), + batch_stride_C(0), + batch_stride_D(0), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda ? make_Coord_with_padding(args.lda) : args.stride_a), + params_B(args.ldb ? make_Coord_with_padding(args.ldb) : args.stride_b), + params_C(args.ldc ? make_Coord_with_padding(args.ldc) : args.stride_c), + params_D(args.ldd ? 
make_Coord_with_padding(args.ldd) : args.stride_d), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_var(const_cast(args.ptr_var)), + ptr_mean(const_cast(args.ptr_mean)), + ptr_gamma(const_cast(args.ptr_gamma)), + ptr_beta(const_cast(args.ptr_beta)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_var(args.batch_stride_var), + batch_stride_mean(args.batch_stride_mean), + batch_stride_gamma(args.batch_stride_gamma), + batch_stride_beta(args.batch_stride_beta), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), + ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)), + semaphore(static_cast(workspace)) { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_var = const_cast(args.ptr_var); + ptr_mean = const_cast(args.ptr_mean); + ptr_gamma = const_cast(args.ptr_gamma); + ptr_beta = const_cast(args.ptr_beta); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); + ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); + ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_var = args.batch_stride_var; + batch_stride_mean = args.batch_stride_mean; + batch_stride_gamma = args.batch_stride_gamma; + batch_stride_beta = args.batch_stride_beta; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmLayernormMainloopFusion() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. 
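+    // For kGemm and kGemmSplitKParallel, the threadblock with K-partition
+    // index kk covers the range [kk * gemm_k_size, min((kk + 1) * gemm_k_size, K)).
+    // Illustrative numbers: K = 1000 split into 4 partitions with
+    // gemm_k_size = 256 yields [0,256), [256,512), [512,768), [768,1000).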
+ // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + + __syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.ptr_gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B, + params.ptr_gather_B_indices); + + // Construct iterators to A var/mean vector + typename Mma::IteratorVarMean iterator_var_mean( + params.problem_size.m(), + static_cast(params.ptr_var), + static_cast(params.ptr_mean), + thread_idx, + MatrixCoord(0, (threadblock_tile_offset.m() * Mma::Shape::kM)) + ); + + // Construct iterators to A scale/bias vector + typename Mma::IteratorGammaBeta iterator_gamma_beta( + problem_size_k, + static_cast(params.ptr_gamma), + static_cast(params.ptr_beta), + thread_idx, + MatrixCoord( + 0, (threadblock_tile_offset.k() * Mma::Shape::kK) + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_var_mean, + iterator_gamma_beta, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. 
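+    // (Descriptive note: for serial split-K the semaphore orders the K
+    //  partitions of this output tile: partition kk waits until the lock
+    //  holds kk, runs its epilogue (later partitions read the partial result
+    //  from D in place of C), then releases kk + 1; the final partition
+    //  resets the lock to 0 for subsequent grids.)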
+ Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.ptr_scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.ptr_scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/include/cutlass/gemm/kernel/grouped_problem_visitor.h new file mode 100644 index 00000000..c5321153 --- /dev/null +++ b/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -0,0 +1,468 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Base scheduler for grouped problems +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing the type of scheduling to perform for the ProblemVisitor +enum class GroupScheduleMode { + // Perform all scheduling on device + kDeviceOnly, + // Precompute on the host the full sequence of problems to access + kHostPrecompute +}; + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseGroupedProblemVisitor { + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : + problem_idx(problem_idx_), problem_start(problem_start_) {} + }; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes(nullptr), problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes(problem_sizes), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + }; + + Params const ¶ms; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseGroupedProblemVisitor( + Params const ¶ms_, + int32_t block_idx + ): + params(params_), + 
tile_idx(block_idx), + problem_tile_start(0), + problem_idx(0) + {} + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), + 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const { + GemmCoord problem = params.problem_sizes[problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ProblemSizeHelper, + typename ThreadblockShape, + GroupScheduleMode GroupScheduleMode_, + int PrefetchTileCount, + int ThreadCount +> +struct GroupedProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct GroupedProblemVisitor: public BaseGroupedProblemVisitor { + using Base = BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage {}; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + GroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + problem_ending_tile(0), + shared_storage(shared_storage_) + { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() { + // Check whether the tile to compute is within the range of the current problem. + int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. 
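+    // Worked example with illustrative sizes: if the warp has fetched a group
+    // whose first three problems have 3, 5, and 2 tiles and group_tile_start
+    // is 0, then after the inclusive prefix sum below lanes 0..2 hold ending
+    // tiles 3, 8, and 10, and every later lane holds 10 (the group total).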
+    int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1);
+
+    // Keep the starting problem for this group in `problem_idx`. This is done to reduce
+    // register pressure. The starting problem for this group is simply the first problem
+    // in the group most recently fetched by the warp.
+    int32_t &group_problem_start = this->problem_idx;
+    group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;
+
+    // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce
+    // register pressure.
+    int32_t &group_tile_start = this->problem_tile_start;
+
+    // Each thread in the warp fetches a separate problem, and the warp advances one group of
+    // problems at a time until it reaches a group whose ending tile is greater than tile_idx.
+    while (group_tile_end <= this->tile_idx) {
+      group_problem_start += kThreadsPerWarp;
+      if (group_problem_start > this->params.problem_count) {
+        return false;
+      }
+
+      // Since `group_tile_start` is a reference to `this->problem_tile_start`, this
+      // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start`
+      // is also set here is used later in `next_tile`.
+      group_tile_start = group_tile_end;
+
+      int lane_idx = threadIdx.x % kThreadsPerWarp;
+      int32_t lane_problem = group_problem_start + lane_idx;
+
+      // Compute the number of tiles in the problem assigned to each thread.
+      problem_ending_tile = 0;
+      if (lane_problem < this->params.problem_count) {
+        cutlass::gemm::GemmCoord problem = this->params.problem_sizes[lane_problem];
+        this->possibly_transpose_problem(problem);
+        cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
+        problem_ending_tile = this->tile_count(grid);
+      }
+
+      // Compute a warp-wide inclusive prefix sum to find the ending tile index of
+      // each thread's problem.
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
+        int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
+        if (lane_idx >= i) {
+          problem_ending_tile += val;
+        }
+      }
+
+      // The total tile count for this group is now in the final position of the prefix sum
+      int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp-1);
+
+      problem_ending_tile += group_tile_start;
+      group_tile_end += tiles_in_group;
+    }
+
+    // The next problem to process is the first one whose ending tile position is
+    // greater than the tile index.
+    int32_t problem_idx_in_group =
+        __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));
+
+    this->problem_idx = group_problem_start + problem_idx_in_group;
+
+    // The starting tile for this problem is the ending tile of the previous problem. In the
+    // case where `problem_idx_in_group` is 0 (the first problem in the group), we do not need
+    // to reset `problem_tile_start`, because it was already set to the previous group's ending
+    // tile in the while loop above.
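+    // Continuing the illustrative example above (hypothetical sizes): if tile_idx is 6 and
+    // lanes 0..3 hold ending tiles 3, 5, 10, and 11, the ballot of
+    // (problem_ending_tile <= tile_idx) sets lanes 0 and 1, so problem_idx_in_group is 2;
+    // the problem spanning tiles [5, 10) owns tile 6, and its starting tile (5) is read
+    // from lane 1 below.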
+ if (problem_idx_in_group > 0) { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Precomputes schedule on host and prefetches into shared memory +// +template +struct GroupedProblemVisitor : public BaseGroupedProblemVisitor { + static_assert(PrefetchTileCount > 0, + "GroupedProblemVisitor with GroupScheduleMode `kHost` currently requires prefetching to shared memory"); + + using Base = BaseGroupedProblemVisitor; + using Params = typename Base::Params; + using ProblemInfo = typename Base::ProblemInfo; + static bool const kRequiresPrecomputation = true; + + static int const kPrefetchTileCount = PrefetchTileCount; + static int const kThreadCount = ThreadCount; + + struct SharedStorage { + // Sequence of problem IDs and starting tiles to compute + cutlass::Array prefetched_problems; + }; + + int32_t tiles_computed; + int32_t iterations_per_block; + int32_t block_load_start; + SharedStorage &shared_storage; + ProblemInfo const *problem_info_ptr; + + // + // Methods + // + CUTLASS_DEVICE + GroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + tiles_computed(0), + shared_storage(shared_storage_), + problem_info_ptr(reinterpret_cast(params_.workspace)) + { + iterations_per_block = (params_.tile_count - 1 + gridDim.x) / gridDim.x; + block_load_start = iterations_per_block * block_idx; + // Start prefetching the first set of tiles to compute + prefetch_tiles(); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx >= this->params.tile_count) { + return false; + } + + int32_t prefetch_idx = (tiles_computed % kPrefetchTileCount); + if (prefetch_idx == 0) { + // Ensure all previous stores to shared memory have been completed + __syncthreads(); + } + + auto problem_info = shared_storage.prefetched_problems[prefetch_idx]; + ++tiles_computed; + + if ((tiles_computed % kPrefetchTileCount) == 0) { + // Begin prefetching next set of tiles. Synchronize first to ensure that + // we don't overwrite the current buffer while someone else is using it. 
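+      // Illustrative example of the workspace layout consumed by prefetch_tiles()
+      // (hypothetical sizes): with block_count = 2 and 5 total tiles, entries_per_block is 3
+      // and host_precompute() lays the workspace out as
+      //   [tile 0, tile 2, tile 4, tile 1, tile 3, unused],
+      // so block 0 reads entries [0, 3) for the tiles it owns (0, 2, 4) and block 1 reads
+      // entries [3, 6) for tiles 1 and 3, matching the order in which tile_idx advances by
+      // gridDim.x on the device.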
+ __syncthreads(); + prefetch_tiles(); + } + + this->problem_idx = problem_info.problem_idx; + this->problem_tile_start = problem_info.problem_start; + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); + int32_t entries_per_block = ((total_tiles - 1 + block_count) / block_count); + return sizeof(ProblemInfo) * entries_per_block * block_count; + } +#if !defined(__CUDACC_RTC__) + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) { + ProblemInfo* host_problem_info_ptr = reinterpret_cast(host_workspace_ptr); + int32_t total_tiles = Base::group_tile_count(host_problem_sizes_ptr, problem_count); + int32_t entries_per_block = (total_tiles - 1 + block_count) / block_count; + + int tile = 0; + int start_tile = 0; + for (int p_idx = 0; p_idx < problem_count; ++p_idx) { + auto problem = host_problem_sizes_ptr[p_idx]; + Base::possibly_transpose_problem(problem); + auto grid = Base::grid_shape(problem); + int tiles = Base::tile_count(grid); + ProblemInfo problem_info(p_idx, start_tile); + for (int i = 0; i < tiles; ++i, ++tile) { + host_problem_info_ptr[(entries_per_block * (tile % block_count)) + (tile / block_count)] = problem_info; + } + start_tile += tiles; + } + } +#endif +private: + CUTLASS_DEVICE + void prefetch_tiles() { + // TODO: Consider changing to use async copies from global to shared mem + CUTLASS_PRAGMA_UNROLL + for (int32_t i = 0; i < kPrefetchTileCount; i += kThreadCount) { + int32_t offset = threadIdx.x + i; + if (offset < kPrefetchTileCount && (tiles_computed + offset < iterations_per_block)) { + shared_storage.prefetched_problems[offset] = problem_info_ptr[block_load_start + tiles_computed + offset]; + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h new file mode 100644 index 00000000..91e7767c --- /dev/null +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -0,0 +1,711 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped Rank2K kernel. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/rank_2k_transpose_operands.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma1_, ///! Threadblock-scoped matrix multiply-accumulate (A*B^T) + typename Mma2_, ///! Threadblock-scoped matrix multiply-accumulate (B*A^T) + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + ComplexTransform OriginalTransformA_, ///! Public-facing transformation on A + ComplexTransform OriginalTransformB_, ///! Public-facing transformation on B + FillMode FillModeC_, ///! Fill Mode for C (kLower or kUpper) + BlasMode BlasMode_, ///! Blas3 computation mode + GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform + bool Transposed = false +> +struct Rank2KGrouped { +public: + + using Mma1 = Mma1_; + using Mma2 = Mma2_; + + static_assert(platform::is_same::value && + platform::is_same::value, + "Kernel-level grouped Rank2K requires that LayoutC be row major."); + + // Define generic Mma for usecases that use Kernel::Mma + using Mma = Mma1_; + + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion to reflect the original layout, + // fill mode, etc. passed in. + // + // Recall that a Rank2K operation performs (A x BT) + (B x AT) + // This is performed via: + // Mma1 = (A x BT) + // Mma2 = (B x AT) + // + // However, if C needs to be transposed, then this is changed to the following: + // Mma1 = (B x AT) + // Mma2 = (A x BT) + // + // The transformation above is achieved by swapping the Layouts/Elements/Transforms/etc. + // of A and B as they are passed into the instantiations of Mma1 and Mma2. 
+ // + // Now, given access to only Mma1 and Mma2, as well as whether a transposition has occurred, + // we wish to retrieve the original Layouts/Elements/etc. for A and B that were passed into + // the device-level call. + // + // The logic to do this (which is made clearer by referencing the above instantiations) is as follows: + // LayoutA = kTransposed ? Mma2::LayoutA : Mma1::LayoutA + // LayoutB = kTransposed ? Mma1::LayoutA : Mma2::LayoutA + // + // We achieve this swapping by passing Mma1::*A and Mma2::*B to Rank2KMapArguments: + using MapArgumentsA = kernel::detail::Rank2KMapArguments< + typename Mma1::IteratorA::Element, + typename Mma1::IteratorA::Layout, + Mma1::kTransformA, + Mma1::IteratorA::AccessType::kElements, + typename Mma2::IteratorA::Element, + typename Mma2::IteratorA::Layout, + Mma2::kTransformA, + Mma2::IteratorA::AccessType::kElements, + typename Mma1::LayoutC, + FillModeC_, + kTransposed + >; + + using ElementA = typename MapArgumentsA::ElementA; + using LayoutA = typename MapArgumentsA::LayoutA; + static int const kAlignmentA = MapArgumentsA::kAlignmentA; + + using MapArgumentsB = kernel::detail::Rank2KMapArguments< + typename Mma2::IteratorA::Element, + typename Mma2::IteratorA::Layout, + Mma2::kTransformA, + Mma2::IteratorA::AccessType::kElements, + typename Mma1::IteratorA::Element, + typename Mma1::IteratorA::Layout, + Mma1::kTransformA, + Mma1::IteratorA::AccessType::kElements, + typename Mma2::LayoutC, + FillModeC_, + kTransposed + >; + + using ElementB = typename MapArgumentsB::ElementA; + using LayoutB = typename MapArgumentsB::LayoutA; + static int const kAlignmentB = MapArgumentsB::kAlignmentA; + + // Use the user-provided TransformA and TransformB, rather than those + // resulting from MapArguments, because Mma1 and Mma2 may have different + // complex transforms than those passed in by the user. 
+ // (See kernel/rank_2k_complex.h for an example of this) + static cutlass::ComplexTransform const kTransformA = OriginalTransformA_; + static cutlass::ComplexTransform const kTransformB = OriginalTransformB_; + + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArgumentsA::LayoutC; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + static FillMode const kFillModeC = MapArgumentsA::kFillModeC; + + // Common type definitions for Mma1 and Mma2 + using Operator = typename Mma1::Operator; + using OperatorClass = typename Mma1::Operator::OperatorClass; + using ThreadblockShape = typename Mma1::Shape; + using WarpShape = typename Mma1::Operator::Shape; + using InstructionShape = typename Mma1::Policy::Operator::InstructionShape; + using ArchTag = typename Mma1::ArchTag; + + static int const kStages = Mma1::kStages; + static BlasMode const kBlasMode = BlasMode_; + +private: + static FillMode const kInternalFillModeC = FillModeC_; + +public: + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma1::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = Rank2KGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount, + kInternalFillModeC>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord *problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params epilogue; + + ElementA ** ptr_A; + ElementB ** ptr_B; + ElementC ** ptr_C; + ElementC ** ptr_D; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + // Only used by device-level operator + GemmCoord *host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): + mode(GemmUniversalMode::kGemm), + problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr), + host_problem_sizes(nullptr) + { + + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode, + GemmCoord *problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params epilogue, + ElementA ** ptr_A, + ElementB ** ptr_B, + ElementC ** ptr_C, + ElementC ** ptr_D, + typename LayoutA::Stride::LongIndex *lda, + typename LayoutB::Stride::LongIndex *ldb, + typename LayoutC::Stride::LongIndex *ldc, + typename LayoutC::Stride::LongIndex *ldd, + GemmCoord *host_problem_sizes=nullptr + ): + mode(mode), + problem_sizes(problem_sizes), + problem_count(problem_count), + threadblock_count(threadblock_count), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + host_problem_sizes(host_problem_sizes) + { + + } + + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + GemmUniversalMode mode; + int batch_count; + + ElementA ** ptr_A; + ElementB ** ptr_B; + ElementC ** ptr_C; + ElementC ** ptr_D; + + typename LayoutA::Stride::LongIndex *lda; + typename LayoutB::Stride::LongIndex *ldb; + 
typename LayoutC::Stride::LongIndex *ldc; + typename LayoutC::Stride::LongIndex *ldd; + + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + lda(nullptr), + ldb(nullptr), + ldc(nullptr), + ldd(nullptr) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0): + problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + output_op(args.epilogue), + ptr_A(args.ptr_A), + ptr_B(args.ptr_B), + ptr_C(args.ptr_C), + ptr_D(args.ptr_D), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldd(args.ldd) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma1::SharedStorage mma1_main_loop; + typename Mma2::SharedStorage mma2_main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + Rank2KGrouped() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // + // Problem visitor. + // + + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + // + // Perform checks to determine whether the results of this threadblock will be needed. + // An example of an unneeded threadblock is one that is assigned to compute in the upper + // portion of a Rank2K kernel filled with mode kLower. + // + // TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously + // returning from `next_tile()`. 
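+      // Illustrative example (hypothetical square tile shape, Mma1::Shape::kM == Mma1::Shape::kN):
+      // with kInternalFillModeC == FillMode::kLower, the checks below skip every tile whose row
+      // index is strictly less than its column index (strictly above the diagonal), and flag
+      // tiles with equal row and column indices as crossing the diagonal so that the fill mode
+      // is applied in their epilogue. The kUpper case mirrors this.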
+ // + + // Early exit if threadblock is out of range + if (grid_shape.m() <= threadblock_tile_offset.m() || + grid_shape.n() <= threadblock_tile_offset.n()) { + // Next tile + problem_visitor.advance(gridDim.x); + continue; + } + + // Skip this tile if Fill Mode is Lower and + // if the entire tile is above the main diagonal (bottom-left corner is at or above the diagonal) + if (kInternalFillModeC == cutlass::FillMode::kLower && + (threadblock_tile_offset.m() + 1) * Mma1::Shape::kM <= threadblock_tile_offset.n() * Mma1::Shape::kN) { + // Next tile + problem_visitor.advance(gridDim.x); + continue; + } + + // Skip this tile if Fill Mode is Upper and + // if the entire tile is below the main diagonal (top-right corner is at or below the diagonal) + if (kInternalFillModeC == cutlass::FillMode::kUpper && + threadblock_tile_offset.m() * Mma1::Shape::kM >= (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { + // Next tile + problem_visitor.advance(gridDim.x); + continue; + } + + bool tile_on_diagonal = false; + // Mark tiles that are being crossed by the main diagonal + // (top-right and bottom-left corners are on either side of the diagonal) + if ((threadblock_tile_offset.m() + 1) * Mma1::Shape::kM > threadblock_tile_offset.n() * Mma1::Shape::kN + && threadblock_tile_offset.m() * Mma1::Shape::kM < (threadblock_tile_offset.n() + 1) * Mma1::Shape::kN) { + tile_on_diagonal = true; + } + + int offset_k = 0; + int problem_size_k = problem_size.k(); + + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < grid_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * problem_size.k(); + } + + offset_k = threadblock_tile_offset.k() * problem_size.k(); + } + + ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::Stride::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::Stride::LongIndex ldm_B = (kTransposed ? 
params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_MxK{ + threadblock_tile_offset.m() * Mma1::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_KxN{ + offset_k, + threadblock_tile_offset.n() * Mma1::Shape::kN + }; + + // Assume identity swizzle + MatrixCoord tb_offset( + threadblock_tile_offset.m() * Mma1::Shape::kM, + threadblock_tile_offset.n() * Mma1::Shape::kN + ); + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands for Mma1 + typename Mma1::IteratorA iterator_A( + Mma1::IteratorA::Params(ldm_A), + ptr_A, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_MxK); + + typename Mma1::IteratorB iterator_BT( + Mma1::IteratorB::Params(ldm_B), + ptr_B, + {problem_size_k, problem_size.n()}, + thread_idx, + tb_offset_KxN); + + // Construct iterators to A and B operands for Mma2 + typename Mma2::IteratorA iterator_B( + Mma2::IteratorA::Params(ldm_B), + ptr_B, + {problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_MxK); + + typename Mma2::IteratorB iterator_AT( + Mma2::IteratorB::Params(ldm_A), + ptr_A, + {problem_size_k, problem_size.n()}, + thread_idx, + tb_offset_KxN); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply for Mma1 (A x BT) + Mma1 mma1(shared_storage.kernel.mma1_main_loop, thread_idx, warp_idx, lane_idx); + + // Construct thread-scoped matrix multiply for Mma2 (B x AT) + Mma2 mma2(shared_storage.kernel.mma2_main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma1::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma1::Shape::kK - 1) / Mma1::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add (A x BT) + mma1( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_BT, + accumulators); + + // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. + if (kBlasMode == BlasMode::kHermitian) { + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); + ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); + + // If TB not on diagonal, FillMode doesn't apply. + FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + ptr_C, + problem_size.mn(), + thread_idx, + tb_offset, + kFillModeTB + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + ptr_D, + problem_size.mn(), + thread_idx, + tb_offset, + kFillModeTB + ); + + Epilogue epilogue( + shared_storage.kernel.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. 
+ epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + + __syncthreads(); + + accumulators.clear(); + } + + // Compute threadblock-scoped matrix multiply-add (B x AT) + mma2( + gemm_k_iterations, + accumulators, + iterator_B, + iterator_AT, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + /* Needed for HER2K where the second HERK is multiplied by conj(alpha) */ + typename EpilogueOutputOp::Params second_her2k_params(conj(params.output_op.alpha), 1); + EpilogueOutputOp output_op_her2k(second_her2k_params); + + // + // Masked tile iterators constructed from members + // + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * grid_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C[problem_idx]); + + // HER2K kernel needs Alpha to be complex and is conj(Alpha) is applied to the second HERK. + if (kBlasMode == BlasMode::kHermitian) { + ptr_C = static_cast(params.ptr_D[problem_idx]); + } + + ElementC *ptr_D = static_cast(params.ptr_D[problem_idx]); + + // If TB not on diagonal, FillMode doesn't apply. + FillMode kFillModeTB = tile_on_diagonal ? kInternalFillModeC : FillMode::kNone; + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + ptr_C, + problem_size.mn(), + thread_idx, + tb_offset, + kFillModeTB + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + ptr_D, + problem_size.mn(), + thread_idx, + tb_offset, + kFillModeTB + ); + + Epilogue epilogue( + shared_storage.kernel.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + if (kBlasMode == BlasMode::kSymmetric) { + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + } else { + epilogue( + output_op_her2k, + iterator_D, + accumulators, + iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h new file mode 100644 index 00000000..70b00b62 --- /dev/null +++ b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -0,0 +1,368 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Problem visitor for grouped Rank2K operations. + + This problem visitor is specialized for Rank2K operations, for which matrix C is upper/lower + triangular. Using a problem visitor designed for GEMMs for Rank2K problems is inefficient + because threadblocks will be frequently assigned to tiles that exit early (e.g., due to + being assigned to a tile in the upper-triangular portion of a lower-triangular problem). + This can lead to load imbalance among threadblocks, as the GEMM-based scheduler + assigns all threadblocks to nearly the same number of tiles, regardless of whether + those tiles exit early. + + Consider an example of a group of four Rank2Ks with matrix C consisting of a grid of 2x2 tiles. + Consider a grid of 8 threadblocks. The default GEMM scheduler will assign threadblocks to + tiles in the following order: + Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 + 0 1 4 5 0 1 4 5 + 2 3 6 7 2 3 6 7 + Assuming that the problems are lower triangular, blocks 1 and 5 are continuously assigned + to inactive tiles. + + This problem visitor aims to assign threadblocks to only those tiles which are in the + upper/lower triangular portion of a given problem. Using the example above, the resulting + assignment would be: + Rank2K 0 Rank2K 1 Rank2K 2 Rank2K 3 + 0 - 3 - 6 - 1 - + 1 2 4 5 7 0 2 3 + + Achieving the schedule above requires a mapping from threadblock ID to tile coordinates (i, j). + We will illustrate this by mapping on a lower-triangular matrix with a 3x3 grid. We first + calculate row and column indices assuming one-indexed rows, tiles, and threadblock IDs, and + then subtract one to convert to zero-indexed. + Col 1 Col 2 Col 3 + ---------------------- + Row 1 | 1 - - + Row 2 | 2 3 - + Row 3 | 4 5 6 + + We next outline this mapping, borrowing from: https://stackoverflow.com/a/40954159 + + Calculating row i given threadblock ID t + ---------------------------------------- + For a given row i, all threadblock IDs t in that row satisfy the following: + t <= 1 + 2 + 3 + ... + (i-1) + i + + The closed-form equation for the right-hand side is: i(i+1)/2. 
+    Using this, we can solve for i given t:
+        t <= i(i+1)/2
+        2t <= i^2 + i
+        2t <= i^2 + i + 0.25 - 0.25
+        2t + 0.25 <= i^2 + i + 0.25
+        2t + 0.25 <= (i + 0.5)^2
+        sqrt(2t + 0.25) - 0.5 <= i
+
+    To account for fractional values, we set:
+        i = ceil(sqrt(2t + 0.25) - 0.5)
+
+    To turn this into a zero-indexed row and work with zero-indexed t, we perform:
+        i = ceil(sqrt(2(t+1) + 0.25) - 0.5) - 1
+          = ceil(sqrt(2t + 2.25) - 0.5) - 1
+
+    Calculating column j given threadblock ID t and row i
+    -----------------------------------------------------
+    For a given row i, all threadblock IDs t in that row also satisfy the following:
+        t > 1 + 2 + 3 + ... + (i-2) + (i-1)
+        --> t > i(i-1)/2
+
+    Threadblock IDs within a given row are sequential, so the one-indexed column ID
+    for one-indexed threadblock ID t and row i is:
+        j = t - (i(i-1)/2)
+
+    The zero-indexed version becomes:
+        j = (t+1) - (i(i+1)/2) - 1
+          = t - (i(i+1)/2)
+
+    Accounting for non-square grids
+    -------------------------------
+    Though the overall output problem size for Rank2K problems is guaranteed to be square, the
+    grids used in computing may not be square due to using non-square threadblock shapes. For
+    example, a threadblock shape of 64x32 operating on a problem of output size 128x128 would
+    result in a grid of 2x4 tiles.
+
+    This case can be handled by noting that the output resembles a square grid of 2x2 "macro tiles"
+    each of which contains 2 "true tiles." We can thus first map a threadblock ID to its "macro tile"
+    using the equations above, and then map it to the "true tile" within its "macro tile." In the example
+    of a 2x4 grid, this mapping would look as follows:
+        "Macro grid"            "True grid"
+        {0, 1}     -            0  1  -  -
+        {2, 3}  {4, 5}          2  3  4  5
+
+    A zero-indexed threadblock ID t is mapped to its "macro tile ID" t_macro as:
+        t_macro = t // r
+    where r is the ratio of the maximum dimension of the grid to the minimum dimension of the grid
+    (i.e., r = 4 / 2 = 2 in the previous example).
+
+    One uses t_macro and the calculations above to find the row and column in the square matrix to
+    obtain i_macro and j_macro (zero-indexed). The mapping from (i_macro, j_macro) --> (i, j)
+    is simply the following:
+        if (ThreadblockShape::M > ThreadblockShape::N):
+            r = ThreadblockShape::M / ThreadblockShape::N
+            i = i_macro
+            j = (j_macro * r) + (t % r)
+        elif (ThreadblockShape::M < ThreadblockShape::N):
+            r = ThreadblockShape::N / ThreadblockShape::M
+            i = (i_macro * r) + (t % r)
+            j = j_macro
+        else:
+            i = i_macro
+            j = j_macro
+
+    Handling cases with grid dimensions that aren't multiples of each other
+    -----------------------------------------------------------------------
+    Even though threadblock shapes M and N are typically multiples of one another, the grid
+    for a given problem may not have dimensions of the same ratio as that of the threadblock.
+    For example, a problem of size 132x132 using a threadblock of shape 64x32 will result
+    in a grid of 3x5 tiles. In this case, there is not an integer number of "true tiles"
+    per "macro tile."
+
+    When this scenario arises, we simply pad the larger dimension of the grid such that
+    there are an integer number of "true tiles" per "macro tile." Thus, the 3x5 grid in
+    the example above will be treated as a 3x6 grid. Row and column positions for each
+    tile are calculated as above. Any threadblocks that map to tiles that are outside the
+    problem range or upper/lower triangular portion (e.g., (2, 5)) will exit early from
+    this problem and may proceed to the next problem in the group.
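+
+    Worked check of the closed-form mapping
+    ---------------------------------------
+    The square-grid mapping derived above can be sanity-checked with a few lines of host code
+    (an illustrative sketch only, not part of the kernel sources; it assumes a lower-triangular
+    3x3 grid of square tiles):
+
+        #include <cmath>
+        #include <cstdio>
+
+        int main() {
+          int const dim = 3;                          // 3x3 grid of tiles
+          int const num_tiles = dim * (dim + 1) / 2;  // tiles on or below the diagonal
+          for (int t = 0; t < num_tiles; ++t) {
+            int i = static_cast<int>(std::ceil(std::sqrt(2 * t + 2.25) - 0.5)) - 1;
+            int j = t - (i * (i + 1)) / 2;
+            std::printf("t=%d -> (%d, %d)\n", t, i, j);
+          }
+          return 0;
+        }
+
+    This prints (0,0), (1,0), (1,1), (2,0), (2,1), (2,2) for t = 0..5, matching the
+    assignment illustrated earlier.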
+ + Handling upper-triangular matrices + ---------------------------------- + The only modification needed for upper-triangular matrices is to swap i_macro and j_macro + in the calculations above. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +namespace detail { +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Helpers for calculating offsets for Rank2K problem visitor. These helpers specifically pertain +// to the conversion from "macro tiles" to "true tiles" in the description above. +// +template < + typename ThreadblockShape, + typename Enable = void +> +struct Rank2KGroupedProblemVisitorOffsetHelper; + +// Partial specialization for the case where threadblock shape M > threadblock shape N +template < + typename ThreadblockShape +> +struct Rank2KGroupedProblemVisitorOffsetHelper< + ThreadblockShape, + typename platform::enable_if< (ThreadblockShape::kM > ThreadblockShape::kN) >::type +> { + static_assert(ThreadblockShape::kM % ThreadblockShape::kN == 0, + "Rank2KGroupedProblemVisitor with threadblock shape M > threadblock shape N " + "requires that threadblock shape M be a multiple of threadblock shape N."); + + static int32_t const kThreadblockSkewRatio = ThreadblockShape::kM / ThreadblockShape::kN; + + CUTLASS_HOST_DEVICE + static int32_t min_dim(cutlass::gemm::GemmCoord grid) { + return grid.m(); + } + + CUTLASS_HOST_DEVICE + static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { + return row; + } + + CUTLASS_HOST_DEVICE + static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { + return (col * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); + } +}; + +// Partial specialization for the case where threadblock shape M < threadblock shape N +template < + typename ThreadblockShape +> +struct Rank2KGroupedProblemVisitorOffsetHelper< + ThreadblockShape, + typename platform::enable_if< (ThreadblockShape::kM < ThreadblockShape::kN) >::type +> { + + static_assert(ThreadblockShape::kN % ThreadblockShape::kM == 0, + "Rank2KGroupedProblemVisitor with threadblock shape M < threadblock shape N " + "requires that threadblock shape N be a multiple of threadblock shape M."); + + static int32_t const kThreadblockSkewRatio = ThreadblockShape::kN / ThreadblockShape::kM; + + CUTLASS_HOST_DEVICE + static int32_t min_dim(cutlass::gemm::GemmCoord grid) { + return grid.n(); + } + + CUTLASS_HOST_DEVICE + static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { + return (row * kThreadblockSkewRatio) + (threadblock_id % kThreadblockSkewRatio); + } + + CUTLASS_HOST_DEVICE + static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { + return col; + } +}; + +// Partial specialization for the case where threadblock shape M == threadblock shape N +// In this case, macro tiles are equivalent to true tiles, so the conversions are +// identity functions. 
+template < + typename ThreadblockShape +> +struct Rank2KGroupedProblemVisitorOffsetHelper< + ThreadblockShape, + typename platform::enable_if< (ThreadblockShape::kM == ThreadblockShape::kN) >::type +> { + + static int32_t const kThreadblockSkewRatio = 1; + + CUTLASS_HOST_DEVICE + static int32_t min_dim(cutlass::gemm::GemmCoord grid) { + return grid.m(); + } + + CUTLASS_HOST_DEVICE + static int32_t macro_row_to_row(int32_t row, int32_t threadblock_id) { + return row; + } + + CUTLASS_HOST_DEVICE + static int32_t macro_col_to_col(int32_t col, int32_t threadblock_id) { + return col; + } +}; + +// Helper for correctly representing problem sizes in grouped kernels +template +struct Rank2KGroupedProblemSizeHelper { + using OffsetHelper = Rank2KGroupedProblemVisitorOffsetHelper; + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + // Return the number of tiles at or below the diagonal (or at and above + // for mode kUpper). We do this by first calculating this value assuming + // we have a square matrix of tiles of size `dim x dim` where `dim` is the + // minimum among {grid.m(), grid.n()}. We then multiply the resulting value + // by OffsetHelper::kThreadblockSkewRatio to account for cases in which there + // are more tiles in one dimension than the other. + int32_t dim = OffsetHelper::min_dim(grid); + int32_t tiles_on_diagonal = dim; + int32_t tiles_below_diagonal = ((dim * (dim - 1)) / 2); + return (tiles_on_diagonal + tiles_below_diagonal) * OffsetHelper::kThreadblockSkewRatio; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Default problem visitor for fill modes kUpper and kLower. 
+// +template +struct Rank2KGroupedProblemVisitor : public GroupedProblemVisitor< + detail::Rank2KGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + static cutlass::FillMode const kFillModeC = FillModeC; + + static_assert(kFillModeC == cutlass::FillMode::kLower || kFillModeC == cutlass::FillMode::kUpper, + "Default Rank2KGroupedProblemVisitor requires fill mode of kLower or kUpper."); + + using ProblemSizeHelper = detail::Rank2KGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using OffsetHelper = typename ProblemSizeHelper::OffsetHelper; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + Rank2KGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, shared_storage_, block_idx) + {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (kFillModeC == cutlass::FillMode::kUpper) { + swap(macro_row, macro_col); + } + + int32_t row = OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h b/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h new file mode 100644 index 00000000..d7ae0bad --- /dev/null +++ b/include/cutlass/gemm/kernel/rank_2k_transpose_operands.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Transpositions for Rank2K problems. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + ComplexTransform TransformA, + int AlignmentA, + typename ElementB_, + typename LayoutB_, + ComplexTransform TransformB, + int AlignmentB, + typename LayoutC_, + FillMode FillModeC_, + bool Transpose +> +struct Rank2KMapArguments { + using ElementA = ElementA_; + using LayoutA = LayoutA_; + static ComplexTransform const kTransformA = TransformA; + static int const kAlignmentA = AlignmentA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + static ComplexTransform const kTransformB = TransformB; + static int const kAlignmentB = AlignmentB; + using LayoutC = LayoutC_; + static FillMode const kFillModeC = FillModeC_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + ComplexTransform TransformA, + int AlignmentA, + typename ElementB_, + typename LayoutB_, + ComplexTransform TransformB, + int AlignmentB, + typename LayoutC_, + FillMode FillModeC_ +> +struct Rank2KMapArguments< + ElementA_, + LayoutA_, + TransformA, + AlignmentA, + ElementB_, + LayoutB_, + TransformB, + AlignmentB, + LayoutC_, + FillModeC_, + true +> { + using ElementA = ElementB_; + using LayoutA = LayoutB_; + static ComplexTransform const kTransformA = TransformB; + static int const kAlignmentA = AlignmentB; + using ElementB = ElementA_; + using LayoutB = LayoutA_; + static ComplexTransform const kTransformB = TransformA; + static int const kAlignmentB = AlignmentA; + using LayoutC = typename layout::LayoutTranspose::type; + static FillMode const kFillModeC = InvertFillMode::mode; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 2fab97d7..9f1de41c 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -632,8 +632,8 @@ struct DefaultMma::value; - static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value; 
+ static const bool transposeA = platform::is_same< LayoutA, layout::ColumnMajor >::value; + static const bool transposeB = platform::is_same< LayoutB, layout::RowMajor >::value; // Define the MmaCore components using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< diff --git a/include/cutlass/gemm/threadblock/default_mma_core.h b/include/cutlass/gemm/threadblock/default_mma_core.h index 1b67f345..4b843b4e 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core.h +++ b/include/cutlass/gemm/threadblock/default_mma_core.h @@ -47,6 +47,7 @@ #include "cutlass/gemm/threadblock/mma_pipelined.h" #include "cutlass/gemm/threadblock/mma_singlestage.h" #include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/mma.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h index 09abd0ec..8e2dd8d0 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h @@ -269,7 +269,7 @@ struct DefaultMmaCore::value); + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); static int const kWarpThreadArrangementStridedB = kWarpSize / kWarpThreadArrangementContiguousB; diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h index bfa04d44..eb44b113 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -1002,7 +1002,8 @@ struct DefaultMmaCore< static_assert( platform::is_same::value || - platform::is_same::value, + platform::is_same::value || + platform::is_same::value, "The operator tag must indicate complex multiplication."); // @@ -1075,6 +1076,8 @@ template < typename Shape_, /// Shape of warp-level matrix multiply operator (concept: GemmShape) typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, /// Layout for A operand typename LayoutA_, /// Layout for B operand @@ -1095,7 +1098,7 @@ template < ComplexTransform TransformB_ > struct DefaultMmaCore< - Shape_, WarpShape_, GemmShape<8, 8, 4>, + Shape_, WarpShape_, InstructionShape_, complex, LayoutA_, complex, LayoutB_, complex, LayoutC_, @@ -1109,7 +1112,7 @@ struct DefaultMmaCore< using Shape = Shape_; using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; + using InstructionShape = InstructionShape_; using ElementA = complex; using LayoutA = LayoutA_; using ElementB = complex; @@ -1410,7 +1413,7 @@ struct DefaultMmaCore::value); + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); static int const kWarpThreadArrangementStridedB = kWarpSize / kWarpThreadArrangementContiguousB; diff --git a/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h b/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h index cf4de84f..f9112452 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_with_access_size.h @@ -1,3 +1,40 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data + layout of the global memory fragments, data types, and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting simt instructions. +*/ + #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h b/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h new file mode 100644 index 00000000..90a72444 --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_layernorm_mainloop_fusion.h @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/threadblock/default_mma_core.h" +#include "cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for Scale/Bias vectors + typename ElementScaleBias, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Use zfill or predicate for SM80 out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone + > +struct DefaultMmaLayernormMainloopFusion { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + /// Define iterators over tiles from scale/bias vectors + using IteratorVarMean = + cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + LayoutScaleBias>; + + /// Define iterators over tiles from scale/bias vectors + using IteratorGammaBeta = + cutlass::transform::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorGammaBeta = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorGammaBeta = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename MmaCore::MmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaLayernormMainloopFusionMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorVarMean, IteratorGammaBeta, SmemIteratorGammaBeta, + CacheOpGammaBeta, + ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, WarpIteratorGammaBeta, Stages, SharedMemoryClear>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h b/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h new file mode 100644 index 00000000..f562035c --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_softmax_mainloop_fusion.h @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined softmax-GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/threadblock/default_mma_core.h" +#include "cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for Scale/Bias vectors + typename ElementScaleBias, + /// Layout type for Scale/Bias vectors + typename LayoutScaleBias, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level 
tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether problem has been transformed. This determines to which operand + /// the softmax is applied. + bool InternalTranspose, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Use zfill or predicate for SM80 out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone + > +struct DefaultMmaSoftmaxMainloopFusion { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpGammaBeta = CacheOpA; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + /// Define iterators over tiles from scale/bias vectors + using IteratorNormSum = + cutlass::transform::threadblock::PredicatedScaleBiasVectorIterator< + cutlass::MatrixShape<1, WarpShape::kN>, + ElementScaleBias, + LayoutScaleBias>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaSoftmaxMainloopFusionMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorNormSum, + ElementAccumulator, layout::RowMajor, + typename MmaCore::MmaPolicy, Stages, InternalTranspose, SharedMemoryClear>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h index 8e2d5ce0..0fc68359 100644 --- a/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h @@ -81,6 +81,8 @@ template < typename Shape_, /// Shape of warp-level matrix multiply operator (concept: GemmShape) typename 
WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, /// Layout of accumulator typename LayoutC_, /// Number of stages @@ -96,7 +98,7 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<8, 8, 4>, + Shape_, WarpShape_, InstructionShape_, complex, layout::ColumnMajor, complex, layout::RowMajor, complex, LayoutC_, @@ -108,7 +110,7 @@ struct DefaultMultistageMmaComplexCore< using Shape = Shape_; using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; + using InstructionShape = InstructionShape_; using ElementA = complex; using LayoutA = layout::ColumnMajor; using ElementB = complex; @@ -210,6 +212,8 @@ template < typename Shape_, /// Shape of warp-level matrix multiply operator (concept: GemmShape) typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, /// Layout of accumulator typename LayoutC_, /// Number of stages @@ -225,7 +229,7 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<8, 8, 4>, + Shape_, WarpShape_, InstructionShape_, complex, layout::ColumnMajor, complex, layout::ColumnMajor, complex, LayoutC_, @@ -237,7 +241,7 @@ struct DefaultMultistageMmaComplexCore< using Shape = Shape_; using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; + using InstructionShape = InstructionShape_; using ElementA = complex; using LayoutA = layout::ColumnMajor; using ElementB = complex; @@ -339,6 +343,8 @@ template < typename Shape_, /// Shape of warp-level matrix multiply operator (concept: GemmShape) typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, /// Layout of accumulator typename LayoutC_, /// Number of stages @@ -354,7 +360,7 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<8, 8, 4>, + Shape_, WarpShape_, InstructionShape_, complex, layout::RowMajor, complex, layout::ColumnMajor, complex, LayoutC_, @@ -366,7 +372,7 @@ struct DefaultMultistageMmaComplexCore< using Shape = Shape_; using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; + using InstructionShape = InstructionShape_; using ElementA = complex; using LayoutA = layout::RowMajor; using ElementB = complex; @@ -469,6 +475,8 @@ template < typename Shape_, /// Shape of warp-level matrix multiply operator (concept: GemmShape) typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, /// Layout of accumulator typename LayoutC_, /// Number of stages @@ -484,7 +492,7 @@ template < /// Cache operation of operand B cutlass::arch::CacheOperation::Kind CacheOpB> struct DefaultMultistageMmaComplexCore< - Shape_, WarpShape_, GemmShape<8, 8, 4>, + Shape_, WarpShape_, InstructionShape_, complex, layout::RowMajor, complex, layout::RowMajor, complex, LayoutC_, @@ -496,7 +504,7 @@ struct DefaultMultistageMmaComplexCore< using Shape = Shape_; using WarpShape = WarpShape_; - using InstructionShape = GemmShape<8, 8, 4>; + using InstructionShape = InstructionShape_; using ElementA = complex; using LayoutA = layout::RowMajor; using ElementB = complex; diff --git a/include/cutlass/gemm/threadblock/mma_base.h 
b/include/cutlass/gemm/threadblock/mma_base.h index 0203c9ce..e2cb4a47 100644 --- a/include/cutlass/gemm/threadblock/mma_base.h +++ b/include/cutlass/gemm/threadblock/mma_base.h @@ -124,6 +124,13 @@ class MmaBase { /// Tensor reference to the B operand using TensorRefB = TensorRef; + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + // // Nested structs // diff --git a/include/cutlass/gemm/threadblock/mma_blas3_multistage.h b/include/cutlass/gemm/threadblock/mma_blas3_multistage.h index d61442ee..08d083ff 100644 --- a/include/cutlass/gemm/threadblock/mma_blas3_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_blas3_multistage.h @@ -139,10 +139,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; diff --git a/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h b/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h new file mode 100644 index 00000000..c17d17c8 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_layernorm_mainloop_fusion_multistage.h @@ -0,0 +1,865 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
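+
+    The fusion follows the usual layer normalization recipe. As a sketch (the
+    exact scaling and any epsilon handling are produced upstream by the kernels
+    that compute the var/mean vectors, so the stored values may already be
+    folded), each element of the A fragment is transformed as
+
+        A'(m, k) = gamma(k) * (A(m, k) - mean(m)) * rsqrt(var(m)) + beta(k)
+
+    before it is fed to the warp-level MMA.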
+ + It loads two loop invariant vectors, mean and var, in the prologue and + stores them in the register file. In the mainloop, it loads two loop + variant vectors, gamma and beta, by using cp.async. We will call + elementwise operation to apply var, mean, gamma, beta between ldmatrix and + warp mma. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/layernorm_scale_bias_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Element type of scale and bias vectors + typename ElementScaleBias_, + /// Layout of scale and bias vectors + typename LayoutScaleBias_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// WarpIterator to load Scale or Bias vector from the shared memory + typename WarpIteratorGammaBeta_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaMainloopFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Element type of scale and bias vectors + using ElementScaleBias = ElementScaleBias_; + + /// Layout of scale and bias vectors + using LayoutScaleBias = LayoutScaleBias_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< WarpIterator to load Scale or Bias vector from the shared memory + using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+  using WarpGemm = typename Policy::Operator::Shape;
+
+  /// Shape describing the number of warps filling the CTA
+  using WarpCount = cutlass::gemm::GemmShape<Shape::kM / WarpGemm::kM,
+                                             Shape::kN / WarpGemm::kN,
+                                             Shape::kK / WarpGemm::kK>;
+
+  /// Number of warp-level GEMM operations
+  static int const kWarpGemmIterations =
+      (WarpGemm::kK / Operator::Policy::MmaShape::kK);
+
+  /// Number of stages
+  static int const kStages = Stages;
+
+  /// Tensor reference to the A operand
+  using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
+
+  /// Tensor reference to the scale and bias vectors
+  using TensorRefGammaBeta = TensorRef<ElementScaleBias, LayoutScaleBias>;
+
+  /// Tensor reference to the B operand
+  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
+
+  //
+  // Nested structs
+  //
+
+  /// Shared storage object needed by threadblock-scoped GEMM
+  class SharedStorage {
+   public:
+    //
+    // Type definitions
+    //
+
+    /// Shape of the A matrix operand in shared memory
+    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
+                               Shape::kK * kStages +
+                                   Policy::SmemPaddingA::kColumn>;
+
+    /// Shape of the A scale and bias vectors in shared memory
+    using ShapeGammaBeta =
+        MatrixShape<1 + Policy::SmemPaddingA::kRow,
+                    2 * Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
+
+    /// Shape of the B matrix operand in shared memory
+    using ShapeB =
+        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
+                    Shape::kN + Policy::SmemPaddingB::kColumn>;
+
+   public:
+    //
+    // Data members
+    //
+
+    /// Buffer for A operand
+    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
+
+    /// Buffer for B operand
+    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
+
+    /// Buffer for A operand Scale and Bias
+    AlignedBuffer<ElementScaleBias, ShapeGammaBeta::kCount> operand_A_gamma_beta;
+
+   public:
+
+    //
+    // Methods
+    //
+
+    /// Returns a layout object for the A matrix
+    CUTLASS_DEVICE
+    static typename Operator::LayoutA LayoutA() {
+      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
+    }
+
+    /// Returns a layout object for the B matrix
+    CUTLASS_HOST_DEVICE
+    static typename Operator::LayoutB LayoutB() {
+      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
+    }
+
+    /// Returns a layout object for the A scale and bias vectors
+    CUTLASS_DEVICE
+    static LayoutScaleBias LayoutScaleBias() {
+      return LayoutScaleBias::packed(
+          {ShapeGammaBeta::kRow, ShapeGammaBeta::kColumn});
+    }
+
+    /// Returns a TensorRef to the A operand
+    CUTLASS_HOST_DEVICE
+    TensorRefA operand_A_ref() {
+      return TensorRefA{operand_A.data(), LayoutA()};
+    }
+
+    /// Returns a TensorRef to the B operand
+    CUTLASS_HOST_DEVICE
+    TensorRefB operand_B_ref() {
+      return TensorRefB{operand_B.data(), LayoutB()};
+    }
+
+    /// Returns a TensorRef to the A operand Scale vector
+    CUTLASS_HOST_DEVICE
+    TensorRefGammaBeta operand_A_gamma_beta_ref() {
+      return TensorRefGammaBeta{operand_A_gamma_beta.data(), LayoutScaleBias()};
+    }
+  };
+
+ protected:
+
+  //
+  // Data members
+  //
+
+  /// Iterator to load a warp-scoped tile of A operand from shared memory
+  typename Operator::IteratorA warp_tile_iterator_A_;
+
+  /// Iterator to load a warp-scoped tile of A operand scale and bias vector
+  /// from shared memory
+  WarpIteratorGammaBeta warp_tile_iterator_A_gamma_beta_;
+
+  /// Iterator to load a warp-scoped tile of B operand from shared memory
+  typename Operator::IteratorB warp_tile_iterator_B_;
+
+public:
+
+  /// Construct from tensor references
+  CUTLASS_DEVICE
+  MmaMainloopFusionBase(
+      ///< Shared storage needed for internal use by threadblock-scoped GEMM
+      SharedStorage &shared_storage,
+      ///< ID within the threadblock
+      int thread_idx,
+      ///< ID of warp
+      int warp_idx,
+      ///< ID of each thread within a warp
+      int lane_idx)
+      : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
+        warp_tile_iterator_A_gamma_beta_(
+            shared_storage.operand_A_gamma_beta_ref(), lane_idx),
+        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
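+
+  // As a concrete example (assuming a 128x128x32 threadblock tile,
+  // kStages == 4, half_t operands, and ignoring the SmemPadding terms),
+  // SharedStorage holds roughly:
+  //
+  //   operand_A:            128 * (32 * 4) elements  ~ 32 KB
+  //   operand_B:            (32 * 4) * 128 elements  ~ 32 KB
+  //   operand_A_gamma_beta: 1 * (2 * 32 * 4) elements ~ 0.5 KB
+  //
+  // which is a useful sanity check when choosing Stages against the
+  // shared-memory budget of the target SM.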
+}; + + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of var and mean vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorVarMean_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorGammaBeta_, + /// Iterates over vectors of scale and bias vector in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorGammaBeta_, + /// Cache operation for scale/bias operand + cutlass::arch::CacheOperation::Kind CacheOpGammaBeta, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// WarpIterator to load Scale or Bias vector from the shared memory + typename WarpIteratorGammaBeta_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaLayernormMainloopFusionMultistage : + public MmaMainloopFusionBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of the var and mean vectors in global memory + using IteratorVarMean = IteratorVarMean_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorGammaBeta = IteratorGammaBeta_; + ///< WarpIterator to load Scale or Bias vector from the shared memory + using WarpIteratorGammaBeta = WarpIteratorGammaBeta_; + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Base class + using Base = MmaMainloopFusionBase; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorGammaBeta = SmemIteratorGammaBeta_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpGammaBeta = + CacheOpGammaBeta; + + // + // Dependent types + 
// + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + using WarpLoadedFragmentVarMean = typename IteratorVarMean::Fragment; + using WarpLoadedFragmentGammaBeta = + typename WarpIteratorGammaBeta::Fragment; + + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of A operand scale vector to shared memory + SmemIteratorGammaBeta smem_iterator_A_gamma_beta_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + int warp_idx_m_; + + int warp_idx_n_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaLayernormMainloopFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_A_gamma_beta_(shared_storage.operand_A_gamma_beta_ref(), + thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + warp_idx_m_ = warp_idx_mn % 
Base::WarpCount::kM; + warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( + {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorGammaBeta &iterator_A_gamma_beta, + IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + // Async Copy for operand A scale and bias vector. Scale and bias vectors + // are small. One iteration is enough. + if (group_start_A == 0) { + typename IteratorGammaBeta::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_gamma_beta_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorGammaBeta::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over B operand in global memory + IteratorVarMean iterator_var_mean, + ///< iterator over scale and bias vectors in global memory + 
IteratorGammaBeta iterator_A_gamma_beta, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + // Issue several complete stages + + WarpLoadedFragmentVarMean warp_loaded_frag_var_mean; + iterator_var_mean.add_tile_offset({0, warp_idx_m_}); + iterator_var_mean.load(warp_loaded_frag_var_mean); + + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + // Async Copy for operand A scale and bias vectors. Scale and bias + // vectors are small. One iteration is enough. + { + typename IteratorGammaBeta::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_gamma_beta_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorGammaBeta::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_A_gamma_beta.get(), iterator_A_gamma_beta.valid()); + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_A_gamma_beta.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. 
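+    // cp_async_wait<N> blocks until at most N committed cp.async groups are
+    // still in flight. With N = kStages - 2 and kStages - 1 groups committed
+    // by the prologue above, this guarantees the first stage has landed in
+    // shared memory while the remaining kStages - 2 stages may still be
+    // copying. For example, with kStages == 4 the prologue commits 3 groups
+    // and the wait below returns once group 0 is complete.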
+    cutlass::arch::cp_async_wait<Base::kStages - 2>();
+    __syncthreads();
+
+    // Pair of fragments used to overlap shared memory loads and math
+    // instructions
+    WarpLoadedFragmentA warp_loaded_frag_A[2];
+    WarpLoadedFragmentB warp_loaded_frag_B[2];
+    WarpLoadedFragmentGammaBeta warp_loaded_frag_A_gamma_beta[2];
+    WarpTransformedFragmentA warp_transformed_frag_A[2];
+    WarpTransformedFragmentB warp_transformed_frag_B[2];
+
+    Operator warp_mma;
+    cutlass::gemm::warp::LayernormScaleBiasTransform<
+        WarpTransformedFragmentA, WarpLoadedFragmentVarMean,
+        WarpLoadedFragmentGammaBeta>
+        elementwise_transform;
+
+    this->warp_tile_iterator_A_.set_kgroup_index(0);
+    this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index(0);
+    this->warp_tile_iterator_B_.set_kgroup_index(0);
+
+    this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
+    this->warp_tile_iterator_A_gamma_beta_.load(
+        warp_loaded_frag_A_gamma_beta[0]);
+    this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
+
+    ++this->warp_tile_iterator_A_;
+    ++this->warp_tile_iterator_A_gamma_beta_;
+    ++this->warp_tile_iterator_B_;
+
+    iterator_A.clear_mask(gemm_k_iterations == 0);
+    iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0);
+    iterator_B.clear_mask(gemm_k_iterations == 0);
+
+    int smem_write_stage_idx = Base::kStages - 1;
+    int smem_read_stage_idx = 0;
+
+    warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
+                       warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
+
+    elementwise_transform(warp_transformed_frag_A[0],
+                          warp_loaded_frag_var_mean,
+                          warp_loaded_frag_A_gamma_beta[0]);
+
+    //
+    // Mainloop
+    //
+
+    CUTLASS_GEMM_LOOP
+    for (; gemm_k_iterations > (-Base::kStages + 1);) {
+      //
+      // Loop over GEMM K dimension
+      //
+
+      // Computes a warp-level GEMM on data held in shared memory
+      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
+      CUTLASS_PRAGMA_UNROLL
+      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
+           ++warp_mma_k) {
+
+        // Load warp-level tiles from shared memory, wrapping to k offset if
+        // this is the last group as the case may be.
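+        // The warp-level fragments are double buffered: iteration warp_mma_k
+        // computes with buffer (warp_mma_k % 2) while the loads below fill
+        // buffer ((warp_mma_k + 1) % 2) for the next iteration, so shared
+        // memory loads overlap the math. For example, with
+        // kWarpGemmIterations == 4 the compute/load buffers alternate
+        // 0/1, 1/0, 0/1, 1/0 across the four warp-level MMAs of a stage.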
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_gamma_beta_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_gamma_beta_.load( + warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_gamma_beta_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_loaded_frag_var_mean, + warp_loaded_frag_A_gamma_beta[warp_mma_k % 2]); + } + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_A_gamma_beta, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
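+          // This wait is issued two warp-level MMAs before the end of the
+          // k-group (warp_mma_k + 2 == kWarpGemmIterations) so that the
+          // barrier and the shared-memory pointer bookkeeping below can
+          // overlap the remaining math of the current stage, and the next
+          // stage's tiles are resident before the read iterators wrap to it.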
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_A_gamma_beta.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_A_gamma_beta_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_A_gamma_beta_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_A_gamma_beta_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_A_gamma_beta.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + elementwise_transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_var_mean, + warp_loaded_frag_A_gamma_beta[(warp_mma_k + 1) % 2]); + } + } + + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_multistage.h b/include/cutlass/gemm/threadblock/mma_multistage.h index a3041ea3..d920e3b5 100644 --- a/include/cutlass/gemm/threadblock/mma_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_multistage.h @@ -133,10 +133,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -429,7 +425,7 @@ public: } } - // Waits until kStages-2 stages have committed. + // Waits until stages up to the previous (kStages-2)th stage have committed. cutlass::arch::cp_async_wait(); __syncthreads(); @@ -558,7 +554,7 @@ public: // Inserts a memory fence between stages of cp.async instructions. 
cutlass::arch::cp_async_fence(); - // Waits until kStages-2 stages have committed. + // Waits until stages up to the previous (kStages-2)th stage have committed. arch::cp_async_wait(); __syncthreads(); diff --git a/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h b/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h new file mode 100644 index 00000000..b74b2b17 --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_softmax_mainloop_fusion_multistage.h @@ -0,0 +1,751 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. + + It loads two loop invariant vectors, norm and sum, in the prologue and + stores them in the register file. We will call elementwise operation to + apply norm and sum between ldmatrix and warp mma. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/softmax_scale_bias_transform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
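+///
+/// Unlike the layernorm variant above, this base class keeps only the A and B
+/// operand buffers in shared memory: the norm and sum vectors are loaded once
+/// in the prologue and live entirely in registers, so no gamma/beta staging
+/// buffer or warp-level scale/bias iterator is required.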
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class MmaMainloopFusionBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm::GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMainloopFusionBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
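+///
+/// As a sketch of the fused elementwise step (assuming `norm` holds the
+/// row-wise maximum and `sum` the corresponding denominator produced by an
+/// earlier partial-reduction pass), each element x of the operand selected by
+/// InternalTranspose is replaced by
+///
+///     softmax(x) = exp(x - norm) / sum
+///
+/// between the shared-memory load and the warp-level MMA.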
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over vectors of var and mean vector in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorNormSum_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Whether problem has been transformed. This determines to which operand + /// the softmax is applied. + bool InternalTranspose, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaSoftmaxMainloopFusionMultistage : + public MmaMainloopFusionBase { +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of the var and mean vectors in global memory + using IteratorNormSum = IteratorNormSum_; + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Base class + using Base = MmaMainloopFusionBase; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
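+  ///
+  /// The kAccessesPerGroup constants below split each stage's cp.async
+  /// traffic evenly across the warp-level MMAs of that stage via a ceiling
+  /// division. For example, if IteratorA::ThreadMap::Iterations::kCount == 8
+  /// and Base::kWarpGemmIterations == 4, then kAccessesPerGroupA ==
+  /// (8 + 4 - 1) / 4 == 2, i.e. two cp.async issues are interleaved with each
+  /// warp-level MMA.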
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + using WarpLoadedFragmentNormSum = typename IteratorNormSum::Fragment; + + static bool const kInternalTranspose = InternalTranspose; + + using SoftmaxFragment = typename platform::conditional::type; + + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + int warp_idx_m_; + + int warp_idx_n_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + MmaSoftmaxMainloopFusionMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + warp_idx_m_ = warp_idx_mn % Base::WarpCount::kM; + warp_idx_n_ = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m_, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n_}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 
Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over B operand in global memory + IteratorNormSum iterator_norm_sum, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // + // Prologue + // + // Issue several complete stages + + WarpLoadedFragmentNormSum warp_loaded_frag_norm_sum; + iterator_norm_sum.add_tile_offset({0, warp_idx_m_}); + iterator_norm_sum.load(warp_loaded_frag_norm_sum); + + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + cutlass::gemm::warp::SoftmaxScaleBiasTransform< + SoftmaxFragment, WarpLoadedFragmentNormSum> elementwise_transform; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance(iterator_A, iterator_B); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + if (kInternalTranspose) { + elementwise_transform(warp_transformed_frag_B[0], + warp_loaded_frag_norm_sum); + } else { + elementwise_transform(warp_transformed_frag_A[0], + warp_loaded_frag_norm_sum); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
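+        // Fragments are double buffered: while warp_mma consumes fragment
+        // (warp_mma_k % 2), the next k-group is loaded into fragment
+        // ((warp_mma_k + 1) % 2), and one group of global->shared cp.async
+        // copies for a later stage is issued in the same iteration.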
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (kInternalTranspose) { + elementwise_transform(warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_norm_sum); + } else { + elementwise_transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_loaded_frag_norm_sum); + } + } + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + } else { + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + } + + copy_tiles_and_advance(iterator_A, iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (kInternalTranspose) { + elementwise_transform(warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_norm_sum); + } else { + elementwise_transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_norm_sum); + } + } + } + + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and 
predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_sparse_base.h b/include/cutlass/gemm/threadblock/mma_sparse_base.h index 91790918..42bbcfe2 100644 --- a/include/cutlass/gemm/threadblock/mma_sparse_base.h +++ b/include/cutlass/gemm/threadblock/mma_sparse_base.h @@ -120,6 +120,13 @@ class SparseMmaBase { static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + /// Number of stages static int const kStages = Stages; diff --git a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h index a839e218..39f75233 100644 --- a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h @@ -156,10 +156,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of async copies to load one stage of operand A static int const TBLDGSTSIterationsA = IteratorA::ThreadMap::Iterations::kCount; diff --git a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h index 2bd7f765..e8bc4488 100644 --- a/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_with_reduction_multistage.h @@ -137,10 +137,6 @@ public: /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; diff --git a/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h b/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h new file mode 100644 index 00000000..08fc93b3 --- /dev/null +++ b/include/cutlass/gemm/warp/layernorm_scale_bias_transform.h @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level per channel scale+bias+relu before + matrix multiply-accumulate operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LayernormScaleBiasTransform { + + using T = typename FragmentActivations::Element; + + static int const NumActivations = FragmentActivations::kElements; + static int const NumVarMean = FragmentVarMean::kElements; + static int const NumGammaBeta = FragmentGammaBeta::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 columns and 2 rows + static int const MmaCols = 2; + static int const MmaRows = 2; + + using MmaOperand = Array; + using VarMeanOperand = Array<__half2, MmaScaleBiasPair>; + using GammaBetaOperand = Array; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, + VarMeanOperand const &var_mean, + GammaBetaOperand const &gamma_beta) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t *ptr_activations = reinterpret_cast(&activations); + uint32_t const *ptr_var_mean = reinterpret_cast(&var_mean); + uint32_t const *ptr_gamma_beta = reinterpret_cast(&gamma_beta); + + // Apply per channel scale+bias+relu if the data is not a special NaN + // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. + + // We assumes the pair of FP16 are either both inbound or both out-of-bound. + // It requires C to be an even number. 
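+    // The two packed f16x2 FMAs below compute, per half2 lane:
+    //   x = var_mean[0]   * x + var_mean[1];
+    //   x = gamma_beta[0] * x + gamma_beta[1];
+    // i.e. the precomputed normalization terms are applied first, followed by
+    // the per-channel gamma/beta scale and bias.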
+ asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %1, %2, %3;\n" + " fma.rn.f16x2 %0, %4, %0, %5;\n" + "}\n" + : "=r"(ptr_activations[0]) + : "r"(ptr_var_mean[0]), "r"(ptr_activations[0]), + "r"(ptr_var_mean[1]), + "r"(ptr_gamma_beta[0]), "r"(ptr_gamma_beta[1])); +#else + // TODO: write emulation code + assert(0); +#endif + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentVarMean const &var_mean, + FragmentGammaBeta const &gamma_beta) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + VarMeanOperand const *ptr_var_mean = + reinterpret_cast(&var_mean); + GammaBetaOperand const *ptr_gamma_beta = + reinterpret_cast(&gamma_beta); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], + ptr_var_mean[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows], + ptr_gamma_beta[(i / MmaScaleBiasPair) % MmaCols]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_complex_tensor_op.h index c524e1d5..5054ddaf 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op.h @@ -46,7 +46,6 @@ #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" - #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/warp/mma.h" @@ -251,6 +250,8 @@ template < ComplexTransform TransformA = ComplexTransform::kNone, /// Complex transform on B operand ComplexTransform TransformB = ComplexTransform::kNone, + /// Do source operands need more than one elements + bool GeneralizedOperatorElements = false, /// Used for partial specialization typename Enable = bool > @@ -279,9 +280,7 @@ template < /// Complex transform on A operand ComplexTransform TransformA, /// Complex transform on B operand - ComplexTransform TransformB, - /// Used for partial specialization - typename Enable + ComplexTransform TransformB > class MmaComplexTensorOp< Shape_, @@ -293,8 +292,7 @@ class MmaComplexTensorOp< LayoutC_, Policy_, TransformA, - TransformB, - Enable> { + TransformB> { public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; @@ -565,9 +563,7 @@ template < /// Complex transform on A operand ComplexTransform TransformA, /// Complex transform on B operand - ComplexTransform TransformB, - /// Used for partial specialization - typename Enable + ComplexTransform TransformB > class MmaComplexTensorOp< Shape_, @@ -579,8 +575,7 @@ class MmaComplexTensorOp< LayoutC_, Policy_, TransformA, - TransformB, - Enable> { + TransformB> { public: /// Shape of warp-level matrix operation (concept: GemmShape) using Shape = Shape_; diff --git a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h index 668c309d..f78a1339 100644 --- a/include/cutlass/gemm/warp/mma_simt_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_simt_tile_iterator.h @@ -618,7 +618,7 @@ public: /// Fragment object holding a thread's part of a tile using Fragment = Array; -private: +protected: /// Internal reference cutlass::TensorRef, layout::RowMajor> ref_; diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index 
2c62cd07..1b6570ac 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -295,6 +295,14 @@ public: #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + CUTLASS_PRAGMA_UNROLL for (int n = 0; n < MmaIterations::kColumn; ++n) { @@ -320,6 +328,14 @@ public: } #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + CUTLASS_PRAGMA_UNROLL for (int m = 0; m < MmaIterations::kRow; ++m) { diff --git a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h index 06a7a994..420a8a50 100644 --- a/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_with_reduction_tensor_op.h @@ -146,6 +146,21 @@ public: static bool const kReduceKForA = ReduceKForA_; + static_assert(platform::is_same::value || + platform::is_same::value, + "ElementA needs to be fp16 or bf16."); + + static_assert(platform::is_same::value || + platform::is_same::value, + "ElementB needs to be fp16 or bf16."); + + static_assert(platform::is_same>::value, + "Only supports 16x8x16 tensor core instruction."); + + static_assert(!AccumulatorsInRowMajor, + "Only calls tensor core instructions in column major."); + public: /// Iterates over the A operand in memory @@ -226,30 +241,7 @@ public: MmaOperandC *ptr_D = reinterpret_cast(&D); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n + m_serpentine * MmaIterations::kColumn], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } else { - mma( - ptr_D[m_serpentine + n * MmaIterations::kRow], - ptr_A[m_serpentine], - ptr_B[n], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } - } + assert(0); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) // Serpentine visitation order maximizing reuse of Ra CUTLASS_PRAGMA_UNROLL @@ -260,25 +252,21 @@ public: int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - if (AccumulatorsInRowMajor) { // matrix B is reordered - mma( - ptr_D[n_serpentine + m * MmaIterations::kColumn], + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } else { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], - ptr_A[m], - ptr_B[n_serpentine], - ptr_D[m + n_serpentine * MmaIterations::kRow]); + ptr_D[m + n_serpentine * MmaIterations::kRow]); - if (!kReduceKForA && m == 0) { -// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); -// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); -// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); -// gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); + if (!kReduceKForA && m == 0) { + #if 0 + gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4]); + gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 1]); + gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 2]); + gemm_k_reduction[n_serpentine] += float(B[n_serpentine * 4 + 3]); + #else + uint32_t const *tmp = reinterpret_cast(&B); - uint32_t const *tmp = reinterpret_cast(&B); + if (platform::is_same::value) { asm volatile( "{\n\t" " .reg .f16 low, high;\n\t" @@ -296,48 +284,99 @@ public: "}\n\t" : "+f"(gemm_k_reduction[n_serpentine]) : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); + } else if (platform::is_same::value) { + asm volatile( + "{\n\t" + " .reg .f32 tmp;\n\t" + " shl.b32 tmp, %1, 16;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " and.b32 tmp, %1, 0xffff0000;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " shl.b32 tmp, %2, 16;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " and.b32 tmp, %2, 0xffff0000;\n\t" + " add.f32 %0, tmp, %0;\n\t" + "}\n\t" + : "+f"(gemm_k_reduction[n_serpentine]) + : "r"(tmp[n_serpentine * 2]), "r"(tmp[n_serpentine * 2 + 1])); + } else { + assert(0); } + #endif } if (kReduceKForA && (n == 0)) { -// gemm_k_reduction[m * 2] += float(A[m * 8]); -// gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); -// gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); -// gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); -// -// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); -// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); -// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); -// gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); - + #if 0 + gemm_k_reduction[m * 2] += float(A[m * 8]); + gemm_k_reduction[m * 2] += float(A[m * 8 + 1]); + gemm_k_reduction[m * 2] += float(A[m * 8 + 4]); + gemm_k_reduction[m * 2] += float(A[m * 8 + 5]); + + gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 2]); + gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 3]); + gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 6]); + gemm_k_reduction[m * 2 + 1] += float(A[m * 8 + 7]); + #else uint32_t const *tmp = reinterpret_cast(&A); - asm volatile( - "{\n\t" - " .reg .f16 low, high;\n\t" - " .reg .f32 tmp;\n\t" - " mov.b32 {low, high}, %2;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " mov.b32 {low, high}, %3;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " mov.b32 {low, high}, %4;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " add.f32 %0, tmp, %0;\n\t" - " mov.b32 {low, high}, %5;\n\t" - " cvt.f32.f16 tmp, low;\n\t" - " add.f32 %1, tmp, %1;\n\t" - " cvt.f32.f16 tmp, high;\n\t" - " 
add.f32 %1, tmp, %1;\n\t" - "}\n\t" - : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) - : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); + + if (platform::is_same::value) { + asm volatile( + "{\n\t" + " .reg .f16 low, high;\n\t" + " .reg .f32 tmp;\n\t" + " mov.b32 {low, high}, %2;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " mov.b32 {low, high}, %3;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " mov.b32 {low, high}, %4;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " mov.b32 {low, high}, %5;\n\t" + " cvt.f32.f16 tmp, low;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " cvt.f32.f16 tmp, high;\n\t" + " add.f32 %1, tmp, %1;\n\t" + "}\n\t" + : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) + : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); + + } else if (platform::is_same::value) { + + asm volatile( + "{\n\t" + " .reg .f32 tmp;\n\t" + " shl.b32 tmp, %2, 16;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " and.b32 tmp, %2, 0xffff0000;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " shl.b32 tmp, %3, 16;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " and.b32 tmp, %3, 0xffff0000;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " shl.b32 tmp, %4, 16;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " and.b32 tmp, %4, 0xffff0000;\n\t" + " add.f32 %0, tmp, %0;\n\t" + " shl.b32 tmp, %5, 16;\n\t" + " add.f32 %1, tmp, %1;\n\t" + " and.b32 tmp, %5, 0xffff0000;\n\t" + " add.f32 %1, tmp, %1;\n\t" + "}\n\t" + : "+f"(gemm_k_reduction[m * 2]), "+f"(gemm_k_reduction[m * 2 + 1]) + : "r"(tmp[m * 4]), "r"(tmp[m * 4 + 1]),"r"(tmp[m * 4 + 2]), "r"(tmp[m * 4 + 3])); + + } else { + assert(0); + } + #endif } } } diff --git a/include/cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h b/include/cutlass/gemm/warp/scale_bias_tile_iterator.h similarity index 94% rename from include/cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h rename to include/cutlass/gemm/warp/scale_bias_tile_iterator.h index 85b8dde2..625ce3e6 100644 --- a/include/cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h +++ b/include/cutlass/gemm/warp/scale_bias_tile_iterator.h @@ -57,7 +57,7 @@ //////////////////////////////////////////////////////////////////////////////// namespace cutlass { -namespace conv { +namespace gemm { namespace warp { //////////////////////////////////////////////////////////////////////////////// @@ -77,7 +77,7 @@ template < int Threads, /// Number of partitions along K dimension int PartitionsK_ = 1> -class WarpIteratorScaleBias; +class ScaleBiasTileIterator; //////////////////////////////////////////////////////////////////////////////// @@ -99,7 +99,7 @@ template < typename Policy_, /// Number of partitions along K dimension int PartitionsK_> -class WarpIteratorScaleBias { public: /// Shape of tile to load (concept: PitchLinearShape) @@ -167,14 +167,14 @@ class WarpIteratorScaleBias::value / 8; return *this; @@ -194,7 +194,7 @@ class WarpIteratorScaleBias::value * kElementsPerAccess / 8; @@ -230,12 +230,12 @@ class WarpIteratorScaleBias -class WarpIteratorScaleBias { public: /// Shape of tile to load (concept: PitchLinearShape) @@ -400,7 +400,7 @@ class WarpIteratorScaleBias, Element, layout::PitchLinear, layout::PitchLinearShape +struct SoftmaxScaleBiasTransform { + + using T = typename FragmentActivations::Element; + + 
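+  // In the transform below, each packed half2 of activations x is updated as
+  // x = h2exp(x - norm_sum[0]) * norm_sum[1], i.e. presumably exp(x - row_max)
+  // scaled by the reciprocal row sum produced by the final softmax reduction
+  // kernel added elsewhere in this change.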
static int const NumActivations = FragmentActivations::kElements; + static int const NumNormSum = FragmentNormSum::kElements; + static int const MmaElements = 2; + // One element has one scale and one bias + static int const MmaScaleBiasPair = 2; + // 16816 has 2 columns and 2 rows + static int const MmaCols = 2; + static int const MmaRows = 2; + + using MmaOperand = Array; + using NormSumOperand = Array<__half2, MmaScaleBiasPair>; + + CUTLASS_DEVICE + void transform(MmaOperand &activations, + NormSumOperand const &norm_sum) { + + __half2* packed_activations = reinterpret_cast<__half2*>(&activations); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MmaElements / 2; ++i) { + __half2 out = ::h2exp(__hsub2(packed_activations[i], norm_sum[2*i])); + packed_activations[i] = __hmul2(out, norm_sum[2*i + 1]); + } + } + + CUTLASS_DEVICE + void operator()(FragmentActivations &activations, + FragmentNormSum const &norm_sum) { + MmaOperand *ptr_activations = reinterpret_cast(&activations); + NormSumOperand const *ptr_norm_sum = + reinterpret_cast(&norm_sum); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < (NumActivations / MmaElements); ++i) { + transform(ptr_activations[i], + ptr_norm_sum[i / (MmaCols * MmaRows) * MmaRows + i % MmaRows]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h new file mode 100644 index 00000000..6a0b2170 --- /dev/null +++ b/include/cutlass/layout/permute.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Defines layout functions used by GEMM+permute path for common tensor or matrix formats. + + Like Layout functions, permute layout functions map logical coordinates to linear memory. They often require additional + data to describe strides between elements. + + Permute layout functions must implement all members in the interface of NoPermute<> defined in this file. Address offset + computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_permute_} as new addresses after permute op. +*/ +#pragma once +#if defined(__CUDACC_RTC__) +#include +#else +#include "assert.h" +#endif +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +class NoPermute { +public: + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + +private: + // + // Data members + // + + MatrixCoord extent_; + + Index stride_unit_; // sizeof(AccessType) / kElementsPerAccess in epilogue's predicated_tile_iterator + + Index stride_permute_; + + Index col_permute_; + Index row_permute_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + NoPermute() { } + + /// Constructor + CUTLASS_HOST_DEVICE + NoPermute(MatrixCoord extent, Index stride_init): extent_(extent) { } + + /// Computes the address offset after Permute Op in Bytes + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord offset_init) { return 0; } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Defines permute layouts of various tensor formats. +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permute layout function for 4-D permuted tensors with output matrix (dimension as [M, N]) reshaped +/// as [M/D1, D1, D2, N/D2]. Then perform permute([0, 2, 1, 3]) on the corresponding output tensor. +template +class Tensor4DPermute0213 { +public: + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + +private: + // + // Data members + // + + MatrixCoord extent_; + + Index stride_permute_; + + Index col_permute_; + Index row_permute_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermute0213() { } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermute0213(MatrixCoord extent, Index stride_init): extent_(extent) { + + /// Update stride_permute with stride_init + stride_permute_ = stride_init / D2 * D1; // stride in Elements + + } + + /// Computes the address offset after Permute Op in Bytes + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord offset_init) { + // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X + // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. 
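+    // Decompose the logical coordinate of the [M, N] output and re-linearize
+    // it after the [0, 2, 1, 3] permutation (with D0 = M / D1, D3 = N / D2):
+    //   row = i * D1 + j,          col = k * D3 + l
+    //   row_permute = i * D2 + k,  col_permute = j * D3 + l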
+ assert(extent_.row() % D1 == 0); + assert(extent_.column() % D2 == 0); + + int D3 = extent_.column() / D2; + + Index col_init = offset_init.column(); + Index row_init = offset_init.row(); + + int l = col_init % D3; + int k = col_init / D3; + int j = row_init % D1; + int i = row_init / D1; + + // After the Permute Op + col_permute_ = l + j * D3; + row_permute_ = k + i * D2; + + return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + } + + /// Return D1 + CUTLASS_HOST_DEVICE + Index d1() const { + return D1; + } + + /// Return D2 + CUTLASS_HOST_DEVICE + Index d2() const { + return D2; + } +}; + +/// Permute layout function for 4-D permuted tensors for BMM with BMM output tensor (dimension as [B, M, N]) reshaped +/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM output tensor. +template +class Tensor4DPermuteBMM0213 { +public: + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + +private: + // + // Data members + // + + MatrixCoord extent_; + + Index stride_permute_; + + Index col_permute_; + Index row_permute_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0213() { } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor4DPermuteBMM0213(MatrixCoord extent, Index stride_init): extent_(extent) { + + /// Update stride_permute with stride_init + stride_permute_ = stride_init * D1; // stride in Elements + + } + + /// Computes the address offset after Permute Op in Bytes + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord offset_init) { + + // The batch index for BMM + Index BMM_batch_idx = blockIdx.z; + + // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X + // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. + int D2 = extent_.row(); + int D3 = extent_.column(); + + Index col_init = offset_init.column(); + Index row_init = offset_init.row(); + + int l = col_init; + int k = row_init; + int j = BMM_batch_idx % D1; + int i = BMM_batch_idx / D1; + + // After the Permute Op + col_permute_ = l + j * D3; + row_permute_ = k + i * D2; + + return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + } + + /// Return D1 + CUTLASS_HOST_DEVICE + Index d1() const { + return D1; + } +}; + +/// Permute layout function for 5-D permuted tensors with output matrix (dimension as [M, N]) reshaped +/// as [M/T1, T1, T2, T3, N/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding output tensor. 
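+/// Equivalently: with T0 = M/T1 and T4 = N/(T2*T3), a coordinate decomposed as
+/// row = i*T1 + j and col = (k*T3 + l)*T4 + m maps to row' = k*T0 + i and
+/// col' = (l*T1 + j)*T4 + m.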
+template +class Tensor5DPermute20314 { +public: + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + +private: + // + // Data members + // + + MatrixCoord extent_; + + Index stride_permute_; + + Index col_permute_; + Index row_permute_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute20314() { } + + /// Constructor + CUTLASS_HOST_DEVICE + Tensor5DPermute20314(MatrixCoord extent, Index stride_init): extent_(extent) { + + /// Update stride_permute with stride_init + stride_permute_ = stride_init / T2 * T1; // stride in Elements + + } + + /// Computes the address offset after Permute Op in Bytes + CUTLASS_HOST_DEVICE + LongIndex operator()(MatrixCoord offset_init) { + + // Permute as torch.permute(X1, [2, 0, 3, 1, 4]) -> 5D Tensor indices as [i,j,k,l,m], the dimension of X + // is [T0, T1, T2, T3, T4], after permutation the dim of X1 is [T2, T0, T3, T1, T4]. + int T0 = extent_.row() / T1; + int T4 = extent_.column() / T2 / T3; + + Index col_init = offset_init.column(); + Index row_init = offset_init.row(); + + int m = col_init % T4; + int l = int(col_init / T4) % T3; + int k = int(col_init / T4) / T3; + int j = row_init % T1; + int i = row_init / T1; + + // After the Permute Op + col_permute_ = m + j * T4 + l * T1 * T4; + row_permute_ = i + k * T0; + + return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/include/cutlass/layout/tensor_op_multiplicand_sm75.h index 4009ac84..ce57b1ce 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm75.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm75.h @@ -48,6 +48,8 @@ namespace layout { /// Template based on element size (in bits) - defined in terms of pitch-linear /// memory and Crosswise size (in elements). +/// This one is the base class of all Ampere/Turing fp16/bf16/int8/int4/int1 +/// tensor core kernels. tf32 TN uses this too. template struct TensorOpMultiplicand { /// Logical rank of tensor @@ -321,6 +323,7 @@ struct TensorOpMultiplicandCongruous { /// Template based on element size (in bits) - defined in terms of pitch-linear /// memory and Crosswise size (in elements). +/// This one is just for TF32 NT kernel. template struct TensorOpMultiplicandCongruous<32, Crosswise> { /// Logical rank of tensor diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm80.h b/include/cutlass/layout/tensor_op_multiplicand_sm80.h index 33602879..3a79bf1f 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm80.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm80.h @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief + \brief layouts needed by Ampere fp64 tensor core kernels. 
*/ #pragma once diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index d7aad5ab..ad628e8a 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -185,7 +185,8 @@ struct NumericConverter { CUTLASS_DEVICE static result_type convert(source_type const & s) { - int32_t intermediate = __float2int_rn(s); + int32_t intermediate; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); return static_cast(intermediate); } @@ -206,7 +207,8 @@ struct NumericConverter { CUTLASS_DEVICE static result_type convert(source_type const & s) { - int32_t intermediate = __float2int_rz(s); + int32_t intermediate; + asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); return static_cast(intermediate); } @@ -228,7 +230,14 @@ struct NumericConverter { static result_type convert(source_type const & s) { std::fesetround(FE_TONEAREST); - int32_t intermediate = (result_type)std::nearbyint(s); + int32_t intermediate = (int32_t)std::nearbyint(s); + + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); } @@ -246,7 +255,14 @@ struct NumericConverter { static result_type convert(source_type const & s) { std::fesetround(FE_TOWARDZERO); - int32_t intermediate = (result_type)std::nearbyint(s); + int32_t intermediate = (int32_t)std::nearbyint(s); + + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); } diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index ff6e3db6..6b8a626f 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -55,7 +55,7 @@ * (2) Re-implementations of STL functions and types: * - C++ features that need the \p __device__ annotation. These are * placed into the \p platform namespace. 
- * - \p abs + * - \p abs * - \p plus * - \p less * - \p greater @@ -452,6 +452,7 @@ struct is_base_of typename remove_cv::type>::value) || (is_same::type, typename remove_cv::type>::value)> {}; + #else using std::is_same; @@ -842,7 +843,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint32_t lowest() noexcept { return 0;} CUTLASS_HOST_DEVICE - static constexpr uint32_t max() noexcept { return 4294967295;} + static constexpr uint32_t max() noexcept { return 4294967295U;} static constexpr bool is_integer = true; }; @@ -851,7 +852,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint16_t lowest() noexcept { return 0;} CUTLASS_HOST_DEVICE - static constexpr uint16_t max() noexcept { return 65535;} + static constexpr uint16_t max() noexcept { return 65535U;} static constexpr bool is_integer = true; }; @@ -860,7 +861,7 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static constexpr uint8_t lowest() noexcept { return 0;} CUTLASS_HOST_DEVICE - static constexpr uint8_t max() noexcept { return 255;} + static constexpr uint8_t max() noexcept { return 255U;} static constexpr bool is_integer = true; }; diff --git a/include/cutlass/reduction/kernel/reduce_softmax_final.h b/include/cutlass/reduction/kernel/reduce_softmax_final.h new file mode 100644 index 00000000..e734c292 --- /dev/null +++ b/include/cutlass/reduction/kernel/reduce_softmax_final.h @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a final reduction for softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +template < + typename ElementNorm_, + typename ElementSum_, + typename ElementSoftmaxCompute_, + typename ThreadblockShape_, + bool GroupedProblem = false +> +class ApplySoftmaxFinalReduction { +public: + + using ElementNorm = ElementNorm_; + using ElementSum = ElementSum_; + using ElementSoftmaxCompute = ElementSoftmaxCompute_; + using ThreadblockShape = ThreadblockShape_; + static const bool isGroupedProblem = GroupedProblem; + + // + // Arguments + // + + struct Arguments { + + cutlass::gemm::GemmCoord* problem_sizes; + cutlass::gemm::GemmCoord problem_size; + ElementNorm* block_Norm; + ElementSum* block_Sum; + int64_t* offset_Norm_Device; + int64_t* offset_Sum_Device; + int64_t batch_stride_Max; + int64_t batch_stride_Sum; + + // + // Methods + // + Arguments() { } + + // Non-grouped constructor without batching + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + } + + // Non-grouped constructor with batching + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum, + int64_t batch_stride_Max, + int64_t batch_stride_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + batch_stride_Max(batch_stride_Max), + batch_stride_Sum(batch_stride_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr) + { + + } + + + // Grouped constructor + Arguments( + cutlass::gemm::GemmCoord *problem_sizes, + ElementNorm* block_Norm, + ElementSum* block_Sum, + int64_t* offset_Norm_Device, + int64_t* offset_Sum_Device + ): + problem_sizes(problem_sizes), + problem_size(cutlass::gemm::GemmCoord(0, 0, 0)), + block_Norm(block_Norm), + block_Sum(block_Sum), + offset_Norm_Device(offset_Norm_Device), + offset_Sum_Device(offset_Sum_Device) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplySoftmaxFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Full reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + int tid = threadIdx.x; + int bid = blockIdx.x; + int bdim = blockDim.x; + + int block_batch = blockIdx.z; + + // defining three vars for a general reduction module + cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; + int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; + int access_offset = isGroupedProblem ? 
0 : bid * bdim; + + if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; + + ElementNorm *curr_ptr_Max = isGroupedProblem ? \ + params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ + params.args.block_Norm + block_batch * params.args.batch_stride_Max; + ElementSum *curr_ptr_Sum = isGroupedProblem ? \ + params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ + params.args.block_Sum + block_batch * params.args.batch_stride_Sum; + + int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + + using ConvertSumOutput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + using ConvertSum = cutlass::NumericConverter; + using ConvertNorm = cutlass::NumericConverter; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + ConvertSumOutput convert_sum_output; + ConvertNormOutput convert_norm_output; + + uint32_t float_max_bits = 0xff7fffff; + float min_float = reinterpret_cast(float_max_bits); + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { + ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; + ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; + ElementNorm *access_n_bak = access_n; + ElementSum *access_s_bak = access_s; + ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); + ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); + ElementNorm fetch_n; + ElementSum fetch_s; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); + access_n += problem_size.m(); + } + + access_n = access_n_bak; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + cutlass::arch::global_load(fetch_s, access_s, true); + sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); + access_n += problem_size.m(); + access_s += problem_size.m(); + } + + ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; + + access_n = access_n_bak; + access_s = access_s_bak; + + access_n[0] = convert_norm_output(max_val); + access_s[0] = convert_sum_output(inv_sum); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/include/cutlass/semaphore.h b/include/cutlass/semaphore.h index 5765f48f..48f5b01a 100644 --- a/include/cutlass/semaphore.h +++ b/include/cutlass/semaphore.h @@ -90,17 +90,19 @@ public: /// Waits until the semaphore is equal to the given value CUTLASS_DEVICE void wait(int status = 0) { - +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) while( __syncthreads_and(state != status) ) { fetch(); } __syncthreads(); +#endif } /// Updates the lock with the given result CUTLASS_DEVICE void release(int status = 0) { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) || defined(__CUDACC_RTC__) __syncthreads(); if (wait_thread) { @@ -110,6 +112,7 @@ public: asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); #endif } +#endif } }; diff --git a/include/cutlass/transform/thread/unaryOp.h b/include/cutlass/transform/thread/unaryOp.h index 7696cf73..1ad2225c 100644 --- a/include/cutlass/transform/thread/unaryOp.h +++ 
b/include/cutlass/transform/thread/unaryOp.h @@ -59,16 +59,16 @@ class UnaryOp "Unary Operator not supported."); FragmentOut out; - if( platform::is_same::value ) + if (platform::is_same::value ) { CUTLASS_PRAGMA_UNROLL - for(int i=0; i < FragmentIn::kElements; ++i){ + for (int i=0; i < FragmentIn::kElements; ++i){ out[i] = static_cast(in[i]); } } - else if( platform::is_same::value ) + else if (platform::is_same::value ) { - for(int i=0; i < FragmentIn::kElements; ++i){ + for (int i=0; i < FragmentIn::kElements; ++i){ out[i] = conj(static_cast(in[i])); } } @@ -87,11 +87,11 @@ class UnaryOp platform::is_same::value, "Unary Operator not supported."); - if( platform::is_same::value ) + if (platform::is_same::value ) { return in; } - else if( platform::is_same::value ) + else if (platform::is_same::value ) { for(int i=0; i < FragmentIn::kElements; ++i){ in[i] = conj(in[i]); @@ -99,9 +99,7 @@ class UnaryOp } return in; } -}; + }; + } + } } -} -} - - diff --git a/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h new file mode 100644 index 00000000..4338ed05 --- /dev/null +++ b/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -0,0 +1,375 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + It can be used to load the gamma and beta vectors of layernorm which is loop variant. 
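+
+   Half of the participating threads (thread_id < kThreads) walk the scale
+   vector and the other half walk the bias vector, with a single predicate
+   guarding the residue tile along the contiguous (K) dimension.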
+ + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorAccessIterator +/// +template +class PredicatedScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; + + using AccessType = AlignedArray; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + BytePointer pointer_; + + TensorCoord thread_offset_; + + int problem_size_k_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + bool guard_; + + TensorCoord::Index residue_size_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 
0 : kThreads; + + problem_size_k_ = problem_size_k; + + is_residue_tile_ = true; + + residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; + + if (residue_size_ == 0) { + residue_size_ = ThreadblockShape::kContiguous; + } + + guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + guard_ = threadIdx.x < kThreads * 2; + + TensorCoord offset = is_residue_tile_ ? + TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) + : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); + + thread_offset_ = + thread_offset_ + + offset; + + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + (thread_offset_.contiguous() * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + guard_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return guard_; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Extent of tensor + int problem_size_k, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size_k, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + int problem_size_k, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h new file mode 100644 index 00000000..66f4783d --- /dev/null +++ b/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h @@ -0,0 +1,328 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + This can be used to load var and mean vectors in layernorm which is loop invariant. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorIterator +/// +template +class PredicatedScaleBiasVectorIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 1; + + using AccessType = AlignedArray; + + static int const kIterations = WarpShape::kContiguous / 8; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + ConstPointer scale_pointer_; + ConstPointer bias_pointer_; + + /// Size of tensor + int problem_size_; + + int32_t thread_offset_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of tensor + int problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : problem_size_(problem_size), + scale_pointer_(scale_pointer), + bias_pointer_(bias_pointer) { + + thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; + } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of tensor + int problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.fill(__float2half2_rn(0.0f)); + __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); + + // load scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + 
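+        // Each thread loads one scale element per iteration; the guard disables
+        // accesses at or beyond problem_size_ (the fragment was zero-filled above).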
cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2].x, + scale_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // load bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2 + 1].x, + bias_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // duplicate scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2].y = frag_ptr[c * 2].x; + } + + // duplicate bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + using Fragment = typename UnderlyingIterator::Fragment; + + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + ///< Extent of tensor + int problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + int problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + 
void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + iterator_.load(frag); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h b/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h similarity index 98% rename from include/cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h rename to include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h index ba8a4bb5..51f8f349 100644 --- a/include/cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h +++ b/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h @@ -47,7 +47,7 @@ //////////////////////////////////////////////////////////////////////////////// namespace cutlass { -namespace conv { +namespace transform { namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -151,7 +151,7 @@ class RegularScaleBiasVectorAccessIterator **Functionality** +# Functionality + - N - Column Major Matrix - T - Row Major matrix - {N,T} x {N,T} - All combinations, i.e. NN, NT, TN, TT @@ -18,8 +20,6 @@ - SpTensorOp - Use Sparse Tensor Core MMA - WmmaTensorOp - Use WMMA abstraction to use Tensor Core MMA -# Functionality - ## Device-level GEMM The following table summarizes device-level GEMM kernels in CUTLASS, organized by opcode class, data type, and layout. diff --git a/media/docs/grouped_scheduler.md b/media/docs/grouped_scheduler.md new file mode 100644 index 00000000..facbd286 --- /dev/null +++ b/media/docs/grouped_scheduler.md @@ -0,0 +1,388 @@ +![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS Grouped Kernel Schedulers") + +[README](/README.md#documentation) > **Grouped Kernel Schedulers** + +# CUTLASS Grouped Kernel Schedulers + +CUTLASS's grouped kernel is a persistent kernel which launches multiple problems (e.g., GEMMs, SYR2Ks) within a +single CUDA kernel launch. + +Unlike a conventional GEMMs in CUTLASS, which launch a number of threadblocks equal to the number +of tiles in the GEMM, CUTLASS grouped kernels typically launch a number of threadblocks that is +fewer than the total number of tiles across all problems in the group. Each threadblock is then +responsible for computing one or more tiles among the problems in the group. The grouped kernel +_scheduler_ (referred to as the _problem visitor_ in code) is responsible for assigning each +threadblock the sequence of tiles that it will compute within the group. + +This document provides background on the functionality of the grouped kernel scheduler, and describes +various optimizations to the grouped kernel scheduler. 
+
+**Outline**
+
+* [Introduction to Grouped Kernel Schedulers](grouped_scheduler.md#introduction-to-grouped-kernel-schedulers)
+* [Grouped GEMM Scheduler](grouped_scheduler.md#grouped-gemm-scheduler)
+* [Grouped Rank2K Scheduler](grouped_scheduler.md#grouped-rank2k-scheduler)
+* [Scheduler Modes](grouped_scheduler.md#scheduler-modes)
+* [Improving Load Balance by Sorting Problems](grouped_scheduler.md#improving-load-balance-by-sorting-problems)
+
+# Introduction to Grouped Kernel Schedulers
+Given a group of problem sizes and a grid of threadblocks, the scheduler's job is to assign
+tiles from problems in the group to threadblocks. Threadblocks in a grouped kernel persistently
+execute a loop of querying the scheduler for the next tile to compute and performing the
+kernel-level operations for that tile (e.g., MMA and epilogue). In pseudocode, this looks as
+follows:
+```c++
+ProblemVisitor problem_visitor;
+
+while (problem_visitor.next_tile()) {
+  //
+  // Get next tile index from scheduler
+  //
+
+  //
+  // Compute MMA and epilogue
+  //
+
+  // Inform the scheduler that we are done with the current tile
+  problem_visitor.advance(gridDim.x);
+}
+```
+
+The key functionality of the grouped kernel scheduler lies in the `next_tile()` method,
+which determines which tile in the group the calling threadblock should compute next, if any.
+
+# Grouped GEMM Scheduler
+The scheduler used by grouped GEMM assigns tiles in the group to threadblocks in a round-robin
+fashion.
+
+Consider, for example, the threadblock-to-tile mapping that occurs for a group of four GEMMs
+each consisting of a grid of 2x2 tiles. Suppose that eight threadblocks are launched. The
+figure below illustrates the threadblock ID assigned to each tile in each GEMM in the group.
+
+![ALT](/media/images/grouped-gemm-schedule-2x2.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with 2x2 grids of tiles")
+
+A similar mapping for problems that do not have the same number of tiles
+is shown below:
+
+![ALT](/media/images/grouped-gemm-schedule-varied.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four GEMMs with varying tile count")
+
+## Computing the schedule for a given block
+Each threadblock in the grouped GEMM computes its own schedule by calling
+the `next_tile()` method described above.
+
+To do this, the threadblock's `ProblemVisitor` maintains a `tile_idx`
+member that is initialized to `blockIdx.x` and is incremented by
+`gridDim.x` between each tile computed (only the x dimension is used
+in the launch configuration for grouped kernels). The scheduler must
+then figure out which GEMM in the group `tile_idx` belongs to, and which tile
+within that problem it maps to.
+
+1. **Determining which GEMM `tile_idx` maps to:** The scheduler determines
+the GEMM to which `tile_idx` belongs by iterating through GEMMs starting with
+the most-recently visited GEMM, and adding the number of tiles within that
+GEMM to a running variable `problem_tile_start`. The scheduler has found the
+correct problem for this tile when `problem_tile_start <= tile_idx < problem_tile_start + tiles_in_problem`.
+
+2. **Determining the tile within a GEMM `tile_idx` maps to:** Once the GEMM
+to which `tile_idx` maps has been located, the specific tile within that
+GEMM that this block should compute is given by `tile_idx - problem_tile_start`.
+Simple rasterization is then performed to map this one-dimensional tile ID
+into the two-dimensional coordinate of the tile to compute in the GEMM, as
+illustrated in the sketch below.
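+
+As a concrete illustration of the two steps above, the following is a minimal,
+hypothetical sketch of this search. The names `map_tile`, `TileMapping`, and
+`tiles_per_problem` are illustrative only and do not appear in CUTLASS; the real
+`ProblemVisitor` additionally resumes from the most-recently visited problem rather
+than rescanning the group from the beginning.
+
+```c++
+// Illustrative only: map a group-wide tile index to the problem that owns it
+// and to the linear tile index within that problem.
+struct TileMapping {
+  int problem_idx;       // index of the owning problem within the group
+  int tile_in_problem;   // linear tile index within that problem
+};
+
+inline TileMapping map_tile(int tile_idx,
+                            int const *tiles_per_problem,  // tile count of each problem
+                            int problem_count) {
+  int problem_tile_start = 0;
+  for (int p = 0; p < problem_count; ++p) {
+    int tiles_in_problem = tiles_per_problem[p];
+    // Found when problem_tile_start <= tile_idx < problem_tile_start + tiles_in_problem
+    if (tile_idx < problem_tile_start + tiles_in_problem) {
+      return {p, tile_idx - problem_tile_start};
+    }
+    problem_tile_start += tiles_in_problem;
+  }
+  return {-1, -1};  // tile_idx lies beyond the last tile in the group
+}
+```
+
+The returned `tile_in_problem` is then rasterized into a two-dimensional tile
+coordinate, e.g. row `tile_in_problem / grid_n` and column `tile_in_problem % grid_n`
+for a problem whose grid has `grid_n` tile columns.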
+ +We describe how this search is accelerated in [Scheduler Modes](grouped_scheduler.md#scheduler-modes). + +# Grouped Rank2K Scheduler +The previous section described the operation of the scheduler used +for grouped GEMM kernels. While this scheduler is sufficient for +correctly implementing grouped Rank2K operations (i.e., SYR2K and HER2K), it leads to significant inefficiencies. + +We next describe these inefficiencies as well as how the CUTLASS +grouped Rank2K scheduler overcomes them. + +## Inefficiency of grouped GEMM scheduler for grouped Rank2K problems +The grouped GEMM scheduler assumes that every tile in every GEMM in the group will +ultimately affect the output of the problem. This is not the case for Rank2K +problems, for which matrix C is either upper or lower triangular. Using the default +grouped GEMM scheduler for such problems will thus lead to threadblocks frequently +being assigned to tiles that exit early (e.g., due to being assigned to a tile in the +upper-triangular portion of a lower-triangular problem). This further leads to load +imbalance among threadblocks, as the grouped GEMM scheduler assigns nearly the same +number of tiles to all threadblocks, regardless of how many tiles are truly active. + +Consider an example of a group of four SYR2K problems, each with matrix C consisting +of a grid of 2x2 tiles. Matrix C in each problem is lower triangular, indicated by +shaded tiles. Consider that eight threadblocks are launched to compute the grouped +problem. The default grouped GEMM scheduler will assign threadblocks to tiles in the following order: + +![ALT](/media/images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png "CUTLASS grouped GEMM scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") + +In this case, threadblocks 1 and 5 are continuously assigned to inactive tiles. In +scenarios in which problems within the group have varying size, we have observed +this to still lead to significant load imbalance. + +## Specializing the scheduler for triangular problems +We seek to design a scheduler that more efficiently maps threadblocks to active tiles +for kernels that use triangular output matrices. The scheduler should ideally assign +threadblocks only to those tiles within lower-triangular portion of a +lower-triangular problem (and vice-versa for upper-triangular problems). + +Using the example above, the resulting assignment of threadblocks to tiles from +such a scheduler might be: + +![ALT](/media/images/grouped-syr2k-schedule-ideal.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to four SYR2Ks with 2x2 grids of tiles") + +Achieving this schedule requires mapping from a threadblock ID to tile coordinates + `(i, j)`. + +We will illustrate this by mapping a lower-triangular matrix with a 3x3 grid. We +first calculate row and column indices assuming one-indexed rows, tiles, and +threadblock IDs, and then subtract one to convert to zero-indexed versions. Our +description borrows heavily from the mapping described [here](https://stackoverflow.com/a/40954159). + +![ALT](/media/images/grouped-syr2k-schedule-3x3.png "CUTLASS grouped SYR2K scheduler assigning threadblocks to one SYR2K with a 3x3 grids of tiles") + +### Calculating row `i` given threadblock ID `t` +For a given row i, all threadblock IDs t in that row satisfy the following: +``` +t <= 1 + 2 + 3 + ... + (i-1) + i +``` + +The closed-form equation for the right-hand side is: `i(i+1)/2`. 
+Using this, we can solve for `i` given `t`: +``` +t <= i(i+1)/2 +2t <= i^2 + i +2t <= i^2 + i + 0.25 - 0.25 +2t + 0.25 <= i^2 + i + 0.25 +2t + 0.25 <= (i + 0.5)^2 +sqrt(2t + 0.25) - 0.5 <= i +``` + +To account for fractional values, we set: +``` +i = ceil(sqrt(2t + 0.25) - 0.5) +``` + +To turn this into a zero-indexed row and work with zero-indexed `t`, we perform: +``` +i = ceil(sqrt(2(t+1) + 0.25) - 0.5) - 1 + = ceil(sqrt(2t + 2.25) - 0.5) - 1 +``` + +### Calculating column `j` given threadblock ID `t` and row `i` +For a given row `i`, all threadblock IDs `t` in that row also satisfy the following: +``` + t > 1 + 2 + 3 + ... + (i-2) + (i-1) +--> t > i(i-1)/2 +``` + +Threadblock IDs within a given row are sequential, so the one-indexed column ID +for one-indexed threadblock ID `t` and row `i` is: +``` +j = t - (i(i-1)/2) +``` + +The zero-indexed version becomes: +``` +j = (t+1) - (i(i+1)/2) -1 + = t - (i(i+1)/2) +``` + +### Accounting for non-square grids +Though the overall output problem size for Rank2K problems is guaranteed to be square, the +grids used in computing may not be square due to using non-square threadblock shapes. For +example, a threadblock shape of 64x32 operating on a problem of output size 128x128 would +result in a grid of 2x4 tiles. + +This case can be handled by noting that the output resembles a square grid of 2x2 "macro tiles" +each of which contains 2 "true tiles." We can thus first map a threadblock ID to its "macro tile" +using the equations above, and then map it to the "true tile" within its "macro tile." In the example +of a 2x4 grid, this mapping would look as follows: + +![ALT](/media/images/grouped-syr2k-schedule-macro.png "CUTLASS grouped SYR2K scheduler converting a grid into a 'macro grid' for computing tile mappings for non-square grids") + +A zero-indexed threadblock ID `t` is mapped to its "macro tile ID" `t_macro` as: +``` +t_macro = t // r +``` +Where `r` is the ratio of the maximum dimension of the grid to the +minimum dimension of the grid (i.e., `r = 4 / 2 = 2` in the previous example). + +One uses `t_macro` and the calculations above to find the row and column in the square matrix to +obtain `i_macro` and `j_macro` (zero-indexed). The mapping from `(i_macro, j_macro) --> (i, j)` +is simply the following: +``` +if (ThreadblockShape::M > ThreadblockShape::N): + r = ThreadblockShape::M / ThreadblockShape::N + i = i_macro + j = (j_macro * r) + (t % r) +elif (ThreadblockShape::M < ThreadblockShape::N): + r = ThreadblockShape::N / ThreadblockShape::M + i = (i_macro * r) + (t % r) + j = j_macro +else: + i = i_macro + j = j_macro +``` + +### Handling cases with grid dimensions that aren't multiples of each other +Even though threadblock shapes M and N are typically multiples of one another, the grid +for a given problem may not have dimensions of the same ratio as that of the threadblock. +For example, a problem of size 132x132 using a threadblock of shape 64x32 will result +in a grid of 3x5 tiles. In this case, there is not an integer number of "true tiles" +per "macro tile." + +When this scenario arises, we simply pad the larger dimension of the grid such that +there are an integer number of "true tiles" per "macro tile." Thus, the 3x5 grid in +the example above will be treated as a 3x6 grid. Row and column positions for each +tile are calculated as above. 
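+
+A minimal sketch combining the formulas above for a lower-triangular problem is shown
+below. The function name and the parameters `tb_m`/`tb_n` (the threadblock tile
+dimensions) are hypothetical, and a production scheduler would also need to guard
+against floating-point rounding in the square root; the actual grouped Rank2K
+scheduler performs this mapping on the device.
+
+```c++
+#include <cmath>
+
+// Illustrative sketch: map a zero-indexed tile ID `t` within one lower-triangular
+// problem to tile coordinates (i, j), assuming the grid has been padded so that each
+// "macro tile" contains an integer number of "true tiles".
+inline void lower_triangular_tile_coords(int t, int tb_m, int tb_n, int &i, int &j) {
+  int r = (tb_m > tb_n) ? (tb_m / tb_n) : (tb_n / tb_m);  // ratio of the grid dimensions
+  int t_macro = t / r;                                    // ID within the square "macro grid"
+
+  // Zero-indexed row and column within the square macro grid (formulas derived above)
+  int i_macro = static_cast<int>(std::ceil(std::sqrt(2.0 * t_macro + 2.25) - 0.5)) - 1;
+  int j_macro = t_macro - (i_macro * (i_macro + 1)) / 2;
+
+  if (tb_m > tb_n) {          // grid is wider than it is tall
+    i = i_macro;
+    j = j_macro * r + (t % r);
+  } else if (tb_m < tb_n) {   // grid is taller than it is wide
+    i = i_macro * r + (t % r);
+    j = j_macro;
+  } else {                    // square grid: macro tiles are true tiles
+    i = i_macro;
+    j = j_macro;
+  }
+}
+```
+
+As a quick check against the 3x3 example above, `t = 3` gives
+`i = ceil(sqrt(8.25) - 0.5) - 1 = 2` and `j = 3 - 3 = 0`, i.e., the first tile of the
+third row.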
Any threadblocks that map to tiles that are outside the +problem range or upper/lower triangular portion (e.g., (2, 5)) will exit early from +this problem and may proceed to the next problem in the group. + +### Handling upper-triangular matrices +The only modification needed for upper-triangular matrices is to swap `i_macro` and `j_macro` in the calculations above. + +# Scheduler modes +The grouped kernel schedulers come with two different modes for finding +the next tile for a block to compute. These techniques are controlled by +the [`cutlass::gemm::kernel::GroupScheduleMode`](../../include/cutlass/gemm/kernel/grouped_problem_visitor.h) enum. +We describe each mode in greater detail below. + +## `GroupScheduleMode::kDeviceOnly` (default) +This scheduler mode performs all scheduling work on the device. It parallelizes +the search for the problem that `tile_idx` maps to by having each thread "own" +a different problem and determine whether `tile_idx` falls within the range of +that problem. + +`GroupScheduleMode::kDeviceOnly` performs this parallelization in a warp-wide +fashion. Each thread in the warp loads a problem size indexed by its lane id and +computes the number of tiles in that problem. A warp-wide prefix sum is used to find +the starting tiles for the set of problems the warp is looking at. At the end of the +prefix sum, each thread holds the starting tile index and tile count for a unique +problem in the group. + +While `tile_idx` remains within the range of the problems currently hosted by the +warp, each thread will check whether `tile_idx` is in the range of its current +problem. The matching problem index and its starting tile are then broadcasted to all +threads in the warp. + +## Precomputing schedules on the host: `GroupScheduleMode::kHostPrecompute` +This scheduler attempts to reduce the amount of scheduling performed on the device +by precomputing on the host the sequence of problems that will +be accessed by each block. As described above, all that is needed to map tile_idx to +the specific tile within a problem to compute is the problem ID and the problem's +starting tile (among all of the tiles in the group). Thus, this scheduler precomputes +the problem index and problem starting tile for each tile computed by each block. + +The schedule for an individual block is represented as an array of +`(problem_idx, problem_starting_tile)` tuples. There is one such array per block. +These arrays are produced on the host and copied over to the device. This +representation is optimized for the case in which blocks compute at most one +tile per problem. When a block computes multiple tiles per problem in the group, +the representation above will result in duplicate entries, and thus will be +suboptimal (e.g., `[(3, 20), (3, 20)]` for a block that computes two tiles in +problem 3, which has starting tile index 20). +We have chosen to use the representation described above because grouped kernels +themselves are typically most beneficial when problem sizes are small, and, thus, +blocks compute at most one tile per problem. + +## Which scheduler mode should I use? +Consider the following questions when deciding which scheduling mode to use: + +### How are the parameters used as input to the grouped kernel (e.g., ptrA, lda) set in my application? +If these are set by a previous kernel running on +the device (rather than by the host), you likely want to use `kDeviceOnly`, +as this will minimize additional host-device communication. 
+
+### Can host-side work be overlapped with other device kernels in my application?
+For example, if a grouped GEMM is used as the Nth layer in a neural network,
+host-side precomputation for the grouped GEMM can potentially be overlapped
+with device-side work for layer N-1. In this case `kHostPrecompute` is likely
+a good fit.
+
+### How compute-intensive are the problems in my group?
+The differences in performance between `kHostPrecompute` and `kDeviceOnly` are most
+noticeable for grouped kernels with low computational intensity, for which time spent in
+the scheduler accounts for a significant fraction of the grouped kernel's runtime.
+Intuitively, as problems in a group decrease in computational intensity, a smaller
+fraction of the overall runtime will be consumed in performing MMA operations, leading
+to a larger fraction of the overall runtime being consumed by scheduling logic.
+
+Since the scheduling modes affect only the scheduling logic of the grouped kernels,
+one expects to see most benefit from `kHostPrecompute` for less computationally-intense
+groups.
+
+# Improving Load Balance by Sorting Problems
+The grouped kernel schedulers assign a nearly equal number
+of tiles to each block participating in the grouped kernel. Every tile in the
+group has the same M and N dimensions. However, the K dimension of each
+tile depends on the K dimension of the problem, so tiles may have different
+K dimensions. Thus, the K dimension of the
+tile plays a significant role in determining how long it takes for a given
+tile to be computed.
+
+## Potential problems with imbalanced K dimension
+To ensure that compute load is balanced evenly across blocks, it is important
+that the sum of the K dimensions among all tiles a block computes be similar
+to that of other blocks; if one block computes far more tiles with a large
+value of K than other blocks, it may take longer than the other blocks.
+
+For example, consider the following group of GEMMs:
+```
+0 1152x768x128
+1 1152x768x1024
+2 768x1152x128
+3 768x1152x1024
+```
+If a tile size of 128x128 is used, then each problem will have 54 tiles.
+Thus, there are 216 tiles across the group.
+
+Suppose this grouped GEMM is run on GA100, which has 108 SMs. Suppose that
+the occupancy given the parameters of the grouped GEMM is one -- one threadblock
+can be active at a time on an SM. The grouped GEMM will, thus, run with 108
+persistent threadblocks, each of which computes (216 / 108) = 2 tiles.
+
+Under the round-robin assignment of tiles to threadblocks employed by
+the grouped GEMM scheduler, the assignment of tiles to threadblocks
+in this GEMM will be as follows:
+```
+Threadblocks 0-53: Tiles of size 128x128x128 from problem 0
+Threadblocks 54-107: Tiles of size 128x128x1024 from problem 1
+Threadblocks 0-53: Tiles of size 128x128x128 from problem 2
+Threadblocks 54-107: Tiles of size 128x128x1024 from problem 3
+```
+
+Following this assignment, threadblocks 54-107 perform significantly more
+work than threadblocks 0-53 because they compute two tiles with a K
+dimension of 1024, whereas threadblocks 0-53 compute two tiles with K
+dimension of only 128.
+
+Due to this imbalanced assignment, threadblocks 54-107 will run
+significantly longer than threadblocks 0-53, leaving threadblocks
+0-53 idle for a large fraction of time.
+
+Clearly, a better assignment of tiles to threadblocks for this
+example would involve all threadblocks computing one tile with
+a K dimension of 1024 and one tile with a K dimension of 128.
+This would better balance the workload among threadblocks. + +## Potential for sorting problems to reduce imbalance +A simple way to potentially reduce load imbalance is to sort the problems in a group in +descending order of their K dimension. This can help to improve load balance +because tiles in a group are assigned in a round-robin fashion to blocks +sequentially, so every block will always be assigned next the tile with +the highest K dimension available. + +Considering the example described above, sorting the problem sizes before +executing grouped GEMM improves the runtime of this grouped GEMM on GA100 with each +scheduling mode by around 30%. + +To ease the process of sorting groups and their associated metadata in this +manner, the device-level grouped kernels provide a `sort_problems()` method. +An example of how to use this may be found in the [grouped GEMM example](../../examples/24_gemm_grouped/gemm_grouped.cu). + +Finally, while sorting problems can be helpful in certain scenarios, it is +not guaranteed to improve performance. In some cases, performance can +decrease when sorting problems due to additional conflicting factors that +affect GEMM performance. We recommend profiling your grouped kernel with +and without sorting to see whether it helps in your case. diff --git a/media/images/grouped-gemm-schedule-2x2.png b/media/images/grouped-gemm-schedule-2x2.png new file mode 100755 index 00000000..27d57497 Binary files /dev/null and b/media/images/grouped-gemm-schedule-2x2.png differ diff --git a/media/images/grouped-gemm-schedule-varied.png b/media/images/grouped-gemm-schedule-varied.png new file mode 100755 index 00000000..47de67dc Binary files /dev/null and b/media/images/grouped-gemm-schedule-varied.png differ diff --git a/media/images/grouped-syr2k-schedule-3x3.png b/media/images/grouped-syr2k-schedule-3x3.png new file mode 100755 index 00000000..e98e24b8 Binary files /dev/null and b/media/images/grouped-syr2k-schedule-3x3.png differ diff --git a/media/images/grouped-syr2k-schedule-ideal.png b/media/images/grouped-syr2k-schedule-ideal.png new file mode 100755 index 00000000..28d398ca Binary files /dev/null and b/media/images/grouped-syr2k-schedule-ideal.png differ diff --git a/media/images/grouped-syr2k-schedule-macro.png b/media/images/grouped-syr2k-schedule-macro.png new file mode 100755 index 00000000..1c2a2f98 Binary files /dev/null and b/media/images/grouped-syr2k-schedule-macro.png differ diff --git a/media/images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png b/media/images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png new file mode 100755 index 00000000..24737b5f Binary files /dev/null and b/media/images/grouped-syr2k-schedule-using-grouped-gemm-scheduler.png differ diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index a9bdbd9a..3d3cdc80 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -110,6 +110,7 @@ cutlass_test_unit_add_executable( # F16 conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu + depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu ) if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) @@ -177,12 +178,16 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) # Conv2d (small channel count specializations) conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu - + # Conv2d (Strided Dgrad) 
conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu + conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu # Conv3d conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu + + # Group Conv2d + group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu ) # Conv - TF32 input, F32 output, F32 accumulation diff --git a/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu index 6e467b85..b31107e1 100644 --- a/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu @@ -109,7 +109,7 @@ std::vector Conv2dFewChannelProblemSizes(int c } //////////////////////////////////////////////////////////////////////////////// -#if 0 + TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8, 128x128_64x3_64x64x64) { @@ -201,7 +201,7 @@ TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_ EXPECT_TRUE(test::conv::device::TestSpecificConv2d( Conv2dFewChannelProblemSizes(kChannelCount))); } -#endif + //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2, diff --git a/test/unit/conv/device/conv2d_problems.h b/test/unit/conv/device/conv2d_problems.h index b740161e..01a90910 100644 --- a/test/unit/conv/device/conv2d_problems.h +++ b/test/unit/conv/device/conv2d_problems.h @@ -684,6 +684,154 @@ struct TestbedConv2dProblemSizes { }; + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedGroupConv2dProblemSizes { + + // + // Data members + // + int threadblock_n; + int threadblock_k; + int minimum_channel_size; + + Conv2dProblemVector default_single_group_sizes; + Conv2dProblemVector default_multiple_group_sizes; + + // + // Methods + // + /// Default ctor + TestbedGroupConv2dProblemSizes( + int threadblock_n_, + int threadblock_k_, + int minimum_channel_size_ = 64) + : threadblock_n (threadblock_n_), + threadblock_k (threadblock_k_), + minimum_channel_size (minimum_channel_size_) { + initialize_group_conv2d_default_sizes(); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &default_single_group_sizes, + &default_multiple_group_sizes + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!((problem.C / problem.groups) % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_group_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////// + // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + // One CTA calculates a single group + //////////////////////////////////////////////////////////////////////////////////// + + for (int 
cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { + // groups = 2, 3, 4 + for (int groups = 2; groups < 5; ++groups) { + + int conv_k = cta_per_group_k * threadblock_n * groups; + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) + {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + groups // groups + )); + + } // loop groups + } // loop cta_per_group_k + + // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + //////////////////////////////////////////////////////////////////////////////////// + // One CTA calculate multiple groups: CTA::N % k_per_group = 0 + //////////////////////////////////////////////////////////////////////////////////// + + // 2 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 4}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 2 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 4 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 8}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + + // 4 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + } + +}; + + } // namespace device } // namespace conv } // namespace test diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index 748f83ed..22a1e77e 100644 --- 
a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -85,7 +85,7 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n test::conv::device::Conv2dProblemVector problem_size_list; -#if 0 // run specific problem size in the unit test first +// run specific problem size in the unit test first problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( {1, 4, 4, 8}, // input size (NHWC) {8, 3, 3, 8}, // filter size (KRSC) @@ -93,7 +93,6 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32n {3, 3}, // stride (stride_h, stride_w) {1, 1} // dilation (dilation_h, dilation_w) )); -#endif /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); @@ -281,7 +280,7 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32 test::conv::device::Conv2dProblemVector problem_size_list; -#if 0 // run specific problem size in the unit test first + // run specific problem size in the unit test first problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( {1, 56, 56, 8}, // input size (NHWC) {8, 1, 1, 8}, // filter size (KRSC) @@ -298,8 +297,6 @@ TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32 {1, 1} // dilation (dilation_h, dilation_w) )); -#endif - /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); } diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu new file mode 100644 index 00000000..0c7c289b --- /dev/null +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align4, + 64x64_32x5_32x32x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::tfloat32_t; + using ElementB = cutlass::tfloat32_t; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 4, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<>, + 5, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided, + 4, + 4 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + + test::conv::device::Conv2dProblemVector problem_size_list; + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, 16}, // input size (NHWC) + {8, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, 16}, // input size (NHWC) + {8, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 9f0e04f9..0e9ac9a6 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -602,7 +602,7 @@ bool 
TestAllConv2d( conv_test_sizes, // run user specified sizes conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes -#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled #endif }; @@ -716,7 +716,7 @@ bool TestAllConv2d( return true; } - + // CUTLASS DGRAD's *strided* specialization does not support split-k mode if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad) && diff --git a/test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu b/test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu new file mode 100644 index 00000000..16c93630 --- /dev/null +++ b/test/unit/conv/device/depthwise_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + + +std::vector<cutlass::conv::Conv2dProblemSize> DepthwiseFpropProblemSizes() { + +std::vector<cutlass::conv::Conv2dProblemSize> problems; + +for (int channels = 16; channels < 256; channels += 16) {
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 3, 3, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, channels}, // input size (NHWC) + {channels, 7, 7, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 7, 7, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 7, 7, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +
+ problems.push_back(cutlass::conv::Conv2dProblemSize( + {1, 112, 112, channels}, // input size (NHWC) + {channels, 5, 5, 1}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {2, 2}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, // Convolution mode + 1, // split_k_slices + channels // groups + )); +} +
+return problems; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 128x128_8x2_64x64x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using
ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + + /// Device-level depthwiseFpropKernel instance + using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm60, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificConv2d( + DepthwiseFpropProblemSizes())); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM60_Device_Depthwise_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_simt_f16, + 64x64_8x2_32x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + + /// Device-level depthwiseFpropKernel instance + using depthwiseFpropKernel = typename cutlass::conv::kernel::DefaultDepthwiseFprop< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm60, + cutlass::gemm::GemmShape<64, 64, 8>, + cutlass::gemm::GemmShape<32, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using DepthwiseFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestSpecificConv2d( + DepthwiseFpropProblemSizes())); + +} diff --git a/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu new file mode 100644 index 00000000..46b21647 --- /dev/null +++ b/test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + 
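Note: the group-conv testbed above derives its problem sizes from the threadblock tile, but a grouped Conv2dProblemSize can also be written out by hand with the same constructor the depthwise tests in this diff use, placing the per-group channel count in the filter's C extent. A minimal sketch with illustrative values; the header path and the concrete sizes are assumptions, not taken from the testbed:

    // Sketch: one grouped Fprop problem with 2 groups (illustrative values only).
    #include "cutlass/conv/conv2d_problem_size.h"   // header path assumed

    int const groups = 2;
    cutlass::conv::Conv2dProblemSize grouped_problem(
      {1, 56, 56, 128},           // input size (NHWC), C = 128
      {128, 3, 3, 128 / groups},  // filter size (KRSC), C per group = 64
      {1, 1, 1, 1},               // padding (pad_h, _, pad_w, _)
      {1, 1},                     // stride (stride_h, stride_w)
      {1, 1},                     // dilation (dilation_h, dilation_w)
      cutlass::conv::Mode::kCrossCorrelation, // Convolution mode
      1,                          // split_k_slices
      groups);                    // groups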
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + SingleGroupPerCTA_64x64_64x3_32x32x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_single_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + MultipleGroupPerCTA_128x128_64x3_64x64x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kMultipleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + 
EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Conv2d_Group_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32, + MutipleGroupPerCTA_64x64_64x3_32x32x64) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + /// Device-level Conv2d instance + using Conv2dGroupFpropKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kMultipleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dGroupFprop = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run group conv unit test sizes with device-level Conv2d instance + test::conv::device::TestbedGroupConv2dProblemSizes problem_sizes( + ThreadblockShape::kN, ThreadblockShape::kK, + 128/cutlass::sizeof_bits::value + ); + EXPECT_TRUE(test::conv::device::TestSpecificConv2d(problem_sizes.default_multiple_group_sizes)); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 2aca5c16..f7977cda 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -340,6 +340,24 @@ cutlass_test_unit_add_executable( gemm_grouped_sm80.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_grouped_scheduler + + BATCH_SOURCES ON + BATCH_SIZE 4 + + gemm_grouped_scheduler_sm80.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_grouped_rank_2k_scheduler + + BATCH_SOURCES ON + BATCH_SIZE 4 + + rank_2k_grouped_scheduler_sm80.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_sparse_tensorop_sm80 @@ -540,4 +558,27 @@ cutlass_test_unit_add_executable( hemm_cf32h_cf32n_tensor_op_fast_f32_rs_sm80.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_grouped_blas3 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + # Grouped SYR2K SM80 f64 tests + syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu + syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu + syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu + syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu + + # Grouped SYR2K SM80 cf64 tests + syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu + syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu + syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu + syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu + + # Grouped HER2K SM80 f64 tests + 
her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu + her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu +) + endif() diff --git a/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu b/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu new file mode 100644 index 00000000..86fc083d --- /dev/null +++ b/test/unit/gemm/device/gemm_grouped_scheduler_sm80.cu @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped GEMM problem visitors +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "testbed_grouped_scheduler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Run a series of tests on the testbed +template +void run_tests() { + for (int scale_factor : {8, 16, 32, 64}) { + for (int threadblock_count : {54, 108, 216, 324, 432}) { + for (int problems : {1, 27, 180, 300}) { + Testbed testbed; + testbed.run(problems, threadblock_count, scale_factor); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p128_t128, 64x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + static int const kNumPrefetch = 128; + static int const kThreadCount = 128; + static bool const kTranspose = false; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p128_t128_transpose, 64x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + static int const kNumPrefetch = 128; + static int const kThreadCount = 128; + static bool const kTranspose = true; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static bool const kTranspose = false; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. 
+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t128, 64x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 128; + static bool const kTranspose = false; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 64x32x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static bool const kTranspose = false; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 64x32x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static bool const kTranspose = true; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t256, 32x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static bool const kTranspose = false; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmGroupedScheduler_p256_t256_transpose, 32x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static bool const kTranspose = true; + + using Testbed = test::gemm::device::TestbedGroupedGemmScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kTranspose, + // List of GroupScheduleModes to compare. List must contain at least two. 
+ cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + run_tests(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_grouped_sm80.cu b/test/unit/gemm/device/gemm_grouped_sm80.cu index 921161e4..568292ba 100644 --- a/test/unit/gemm/device/gemm_grouped_sm80.cu +++ b/test/unit/gemm/device/gemm_grouped_sm80.cu @@ -181,7 +181,7 @@ struct GemmGroupedProblemVisitor { } CUTLASS_HOST_DEVICE - int64_t threadblock_index() const { + int64_t threadblock_idx() const { return tile_idx - problem_tile_start; } @@ -193,7 +193,7 @@ struct GemmGroupedProblemVisitor { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { __shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage; @@ -201,18 +201,18 @@ __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { GemmGroupedProblemVisitor problem_visitor( shared_storage, params, - {CtaShapeM, CtaShapeN}, + {ThreadblockShapeM, ThreadblockShapeN}, blockIdx.x); while (problem_visitor.next_tile()) { cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); - int64_t cta_idx = problem_visitor.threadblock_index(); + int64_t threadblock_idx = problem_visitor.threadblock_idx(); cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - int cta_tile_m_idx = int(cta_idx / grid_shape.n()); - int cta_tile_n_idx = int(cta_idx % grid_shape.n()); + int threadblock_tile_m_idx = int(threadblock_idx / grid_shape.n()); + int threadblock_tile_n_idx = int(threadblock_idx % grid_shape.n()); // // Do the MMA @@ -220,13 +220,13 @@ __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { if (threadIdx.x == 0) { #if 0 - printf("Block %d - tile: %lld, problem %d, cta_idx: %lld, cta(m: %d, n: %d)\n", + printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n", blockIdx.x, problem_visitor.tile_index(), problem_visitor.problem_index(), - cta_idx, - cta_tile_m_idx, - cta_tile_n_idx); + threadblock_idx, + threadblock_tile_m_idx, + threadblock_tile_n_idx); #endif } @@ -241,8 +241,8 @@ TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { int32_t problem_count = 16; - int const kCtaShapeM = 64; - int const kCtaShapeN = 64; + int const kThreadblockShapeM = 64; + int const kThreadblockShapeN = 64; std::vector problem_sizes(problem_count); std::vector tile_counts(problem_count); @@ -262,7 +262,7 @@ TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { for (int32_t i = 0; i < problem_count; ++i) { cutlass::gemm::GemmCoord grid_shape = GemmGroupedProblemVisitor::grid_shape( - problem_sizes.at(i), {kCtaShapeM, kCtaShapeN}); + problem_sizes.at(i), {kThreadblockShapeM, kThreadblockShapeN}); int32_t problem_tile_count = (grid_shape.m() * grid_shape.n()); @@ -295,7 +295,7 @@ TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { dim3 grid(108, 1, 1); dim3 block(128, 1, 1); - GroupedBatchedKernel<<< grid, block >>>(params); + GroupedBatchedKernel<<< grid, block >>>(params); // wait cudaDeviceSynchronize(); @@ -705,6 +705,7 @@ TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) 
ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -748,6 +749,7 @@ TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -791,6 +793,7 @@ TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32t_tensorop_f32, 64x64x16_32x32x16) ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; @@ -834,6 +837,7 @@ TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 3, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, cutlass::arch::OpMultiplyAddComplex>::GemmKernel; using Gemm = cutlass::gemm::device::GemmGrouped; diff --git a/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu b/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu index e3b50da1..9837390e 100644 --- a/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu +++ b/test/unit/gemm/device/gemm_universal_f16n_f16t_f32t_tensor_op_f32_sm75.cu @@ -79,7 +79,6 @@ TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32) EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); } - TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32_updated_batch_count) { using ElementOutput = float; @@ -114,4 +113,3 @@ TEST(SM75_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32, 64x64x32_32x32x32_u #endif // #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// - diff --git a/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..5bfa71c8 --- /dev/null +++ b/test/unit/gemm/device/her2k_cf64h_cf64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC 
= cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + 
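How these grouped Rank2K kernels are driven, in sketch form: the device-level handle is presumably obtained by wrapping the composed kernel type, and each testbed.run(24) call then generates and verifies a group of 24 problems in a single launch. The template parameterizations below are assumptions, not confirmed by this diff:

    // Sketch, assuming the device-side wrapper takes the kernel as its sole parameter
    // and the grouped testbed is parameterized by the device type.
    using Rank2K = cutlass::gemm::device::Rank2KGrouped<Rank2Kkernel>;

    test::gemm::device::TestbedGrouped<Rank2K> testbed;  // assumed testbed signature
    bool passed = testbed.run(24);                       // 24 = number of problems in the group
    EXPECT_TRUE(passed);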
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64h_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kConjugate, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kConjugate, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + 
cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..e82d7851 --- /dev/null +++ b/test/unit/gemm/device/her2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +// NOTE: HER2K requires that LayoutA == LayoutB, and that LayoutC == ColumnMajor + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + 
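The NOTE above (HER2K requires LayoutA == LayoutB and a column-major C) is stated only as a comment; when adapting these tests it is cheap to enforce at compile time. A small self-contained sketch using the same layout aliases the tests define:

    #include <type_traits>
    #include "cutlass/layout/matrix.h"

    using LayoutA = cutlass::layout::ColumnMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutC = cutlass::layout::ColumnMajor;

    // Compile-time guards mirroring the HER2K layout requirements noted above.
    static_assert(std::is_same<LayoutA, LayoutB>::value,
                  "HER2K requires LayoutA == LayoutB");
    static_assert(std::is_same<LayoutC, cutlass::layout::ColumnMajor>::value,
                  "HER2K requires a column-major C");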
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_l_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + 
cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Her2KGrouped_cf64n_cf64n_u_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kHermitian>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu b/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu new file mode 100644 index 00000000..6b60986a --- /dev/null +++ b/test/unit/gemm/device/rank_2k_grouped_scheduler_sm80.cu @@ -0,0 +1,234 @@ 
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Tests for grouped Rank2K problem visitors
+*/
+
+#include <iostream>
+
+#include "../../common/cutlass_unit_test.h"
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/kernel/gemm_grouped.h"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+#include "cutlass/gemm/device/gemm_grouped.h"
+
+#include "testbed_grouped_rank_2k_scheduler.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Run a series of tests on the testbed
+template <typename Testbed>
+void run_tests(bool skip_tile_check=false) {
+  for (int scale_factor : {8, 16, 32, 64}) {
+    for (int threadblock_count : {54, 108, 216, 324, 432}) {
+      for (int problems : {1, 27, 180, 300}) {
+        Testbed testbed(skip_tile_check);
+        testbed.run(problems, threadblock_count, scale_factor);
+      }
+    }
+  }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_l, 64x64x32) {
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
+  static int const kNumPrefetch = 128;
+  static int const kThreadCount = 128;
+  static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower;
+
+  using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler<
+    ThreadblockShape,
+    kNumPrefetch,
+    kThreadCount,
+    kFillModeC,
+    // List of GroupScheduleModes to compare. List must contain at least two.
+    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>;
+  run_tests<Testbed>();
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_Device_Rank2KGroupedScheduler_p128_t128_u, 64x64x32) {
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
+  static int const kNumPrefetch = 128;
+  static int const kThreadCount = 128;
+  static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper;
+
+  using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler<
+    ThreadblockShape,
+    kNumPrefetch,
+    kThreadCount,
+    kFillModeC,
+    // List of GroupScheduleModes to compare. List must contain at least two.
+    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>;
+  run_tests<Testbed>();
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x64x32) {
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
+  static int const kNumPrefetch = 256;
+  static int const kThreadCount = 256;
+  static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower;
+
+  using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler<
+    ThreadblockShape,
+    kNumPrefetch,
+    kThreadCount,
+    kFillModeC,
+    // List of GroupScheduleModes to compare. List must contain at least two.
+    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>;
+  run_tests<Testbed>();
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t128_l, 64x64x32) {
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
+  static int const kNumPrefetch = 256;
+  static int const kThreadCount = 128;
+  static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower;
+
+  using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler<
+    ThreadblockShape,
+    kNumPrefetch,
+    kThreadCount,
+    kFillModeC,
+    // List of GroupScheduleModes to compare. List must contain at least two.
+    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>;
+  run_tests<Testbed>();
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 64x32x32) {
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>;
+  static int const kNumPrefetch = 256;
+  static int const kThreadCount = 256;
+  static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower;
+
+  using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler<
+    ThreadblockShape,
+    kNumPrefetch,
+    kThreadCount,
+    kFillModeC,
+    // List of GroupScheduleModes to compare. List must contain at least two.
+    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+    cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>;
+
+  // Skip individual tile check for the non-square SYR2K versions.
We still + // compare the problem visitors with one another + run_tests(/*skip_tile_check=*/true); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 64x32x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; + + using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kFillModeC, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + + // Skip individual tile check for the non-square SYR2K versions. We still + // compare the problem visitors with one another + run_tests(/*skip_tile_check=*/true); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_l, 32x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static cutlass::FillMode const kFillModeC = cutlass::FillMode::kLower; + + using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kFillModeC, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + + // Skip individual tile check for the non-square SYR2K versions. We still + // compare the problem visitors with one another + run_tests(/*skip_tile_check=*/true); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Rank2KGroupedScheduler_p256_t256_u, 32x64x32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 32>; + static int const kNumPrefetch = 256; + static int const kThreadCount = 256; + static cutlass::FillMode const kFillModeC = cutlass::FillMode::kUpper; + + using Testbed = test::gemm::device::TestbedGroupedRank2KScheduler< + ThreadblockShape, + kNumPrefetch, + kThreadCount, + kFillModeC, + // List of GroupScheduleModes to compare. List must contain at least two. + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>; + + // Skip individual tile check for the non-square SYR2K versions. We still + // compare the problem visitors with one another + run_tests(/*skip_tile_check=*/true); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..46280ed0 --- /dev/null +++ b/test/unit/gemm/device/syr2k_cf64n_cf64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 64x32x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_cf64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, 
cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_cf64, 64x64x16_32x32x16) { + + using ElementA = 
cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..e0342fee --- /dev/null +++ b/test/unit/gemm/device/syr2k_cf64n_cf64t_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// 
+ +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..92b501ed --- /dev/null +++ b/test/unit/gemm/device/syr2k_cf64t_cf64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); 
+} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64n_cf64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..7041add0 --- /dev/null +++ b/test/unit/gemm/device/syr2k_cf64t_cf64t_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + 
cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_cf64t_cf64t_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = cutlass::complex; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::complex; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = cutlass::complex; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = cutlass::complex; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAddComplex, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..9d6b45ac --- /dev/null +++ b/test/unit/gemm/device/syr2k_f64n_f64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,483 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, 
LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = 
cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + 
cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64n_u_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..e9ed3ee6 --- /dev/null +++ b/test/unit/gemm/device/syr2k_f64n_f64t_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,273 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + 
cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 
16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64n_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..97f69ffa --- /dev/null +++ b/test/unit/gemm/device/syr2k_f64t_f64n_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,308 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = 
double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); 
+} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64n_u_tensor_op_f64, 64x32x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu b/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu new file mode 100644 index 00000000..bb94dcaa --- /dev/null +++ b/test/unit/gemm/device/syr2k_f64t_f64t_tensor_op_f64_grouped_sm80.cu @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/blas3.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_grouped_rank_2k.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + 
EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 64x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 16>, + cutlass::gemm::GemmShape<64, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + 
cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_l_tensor_op_f64, 128x128x16_32x64x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kLower, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 16>, + cutlass::gemm::GemmShape<32, 64, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x32x16_16x16x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<16, 16, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Syr2kGrouped_f64t_f64t_u_tensor_op_f64, 32x64x16_32x32x16) { + + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = double; + + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 1, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 1, + ElementC, LayoutC, cutlass::FillMode::kUpper, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<32, 64, 16>, + 
cutlass::gemm::GemmShape<32, 32, 16>, + cutlass::gemm::GemmShape<8, 8, 4>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, // kStages + cutlass::arch::OpMultiplyAdd, + cutlass::BlasMode::kSymmetric>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + test::gemm::device::TestbedGrouped testbed; + bool passed = testbed.run(24); + EXPECT_TRUE(passed); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed_grouped.h b/test/unit/gemm/device/testbed_grouped.h index 5ec41618..cfcd4268 100644 --- a/test/unit/gemm/device/testbed_grouped.h +++ b/test/unit/gemm/device/testbed_grouped.h @@ -417,46 +417,27 @@ struct TestbedGrouped { return passed; } - /// Returns the number of threadblocks to launch if the kernel can run on the target - /// device. Otherwise, returns zero. - int sufficient() const { - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } - - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - int occupancy = Gemm::maximum_active_blocks(); - - return properties.multiProcessorCount * occupancy; - } - /// Executes one test bool run( int problem_count, ElementCompute alpha = ElementCompute(1), ElementCompute beta = ElementCompute(0)) { - int threadblock_count = sufficient(); - - // Early exit - if (!threadblock_count) { - return false; - } - this->problem_count = problem_count; // Initialize the problem initialize(); + int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; + } + return true; + } + // Configure the GEMM arguments typename EpilogueOutputOp::Params epilogue_op(alpha, beta); @@ -473,13 +454,17 @@ struct TestbedGrouped { lda.get(), ldb.get(), ldc.get(), - ldd.get() + ldd.get(), + problem_sizes_host.data() ); // Initialize the GEMM object Gemm gemm; - cutlass::Status status = gemm.initialize(args); + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = gemm.initialize(args, workspace.get()); if (status != cutlass::Status::kSuccess) { return false; diff --git a/test/unit/gemm/device/testbed_grouped_rank_2k.h b/test/unit/gemm/device/testbed_grouped_rank_2k.h new file mode 100644 index 00000000..b82c9f3d --- /dev/null +++ b/test/unit/gemm/device/testbed_grouped_rank_2k.h @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + 
cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + auto N = 8 * (rand() % 64) + 24; + auto K = 8 * (rand() % 64) + 24; + cutlass::gemm::GemmCoord problem(N, N, K); + + if (!i) { + problem = cutlass::gemm::GemmCoord(16, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides 
between problems? + } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a Rank2K + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), 
block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Rank2K::kTransformA, + view_B, + Rank2K::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; + } + return true; + } + + // Configure the Rank2K arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = rank2k.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the Rank2K object + status = rank2k.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h new file mode 100644 index 00000000..76287865 --- /dev/null +++ b/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/device_kernel.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static cutlass::FillMode const kFillModeC = FillModeC; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (FillModeC == cutlass::FillMode::kUpper) { + cutlass::swap(macro_row, macro_col); + } + + int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + 
int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + problem_visitor.advance(gridDim.x); + + // + // Early exit conditions + // 1) Out of range + // 2) Upper-triangular block in lower-triangular problem + // 3) Lower-triangular block in upper-triangular problem + // + + if (grid_shape.m() <= tile_offset.m() || + grid_shape.n() <= tile_offset.n()) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && + (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && + tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. 
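The index math in `threadblock_offset` above, together with the kernel's early-exit tests, enumerates only the lower (or upper) triangle of macro tiles. A small self-contained Python sketch of the same closed-form mapping (the `kThreadblockSkewRatio` scaling and the upper-fill row/column swap are omitted) shows how a linear macro-tile id maps to coordinates:

```python
import math

def macro_tile_coords(macro_id):
    # Same closed form as BaselineProblemVisitor::threadblock_offset (lower fill mode):
    # row r owns r+1 tiles, so ids 0, 1, 2, ... walk the triangle row by row.
    row = math.ceil(math.sqrt(2 * macro_id + 2.25) - 0.5) - 1
    col = macro_id - (row + 1) * row // 2
    return row, col

# Brute-force check against an explicit row-major walk of an 8-row triangle.
expected = [(r, c) for r in range(8) for c in range(r + 1)]
assert [macro_tile_coords(i) for i in range(len(expected))] == expected
```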
+ cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + 
visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + if (skip_tile_check) { + return true; + } + + return verify(); + } +}; + +template +struct TestbedGroupedRank2KScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + ThreadCount, + FillModeC>; + + // + // Data members + // + + // Whether to skip checking that the tiles are visited as expected. This is useful + // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped + // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to + // exit early, but which are difficult to detect in tests without reimplementing + // this functionality. + bool skip_tile_check; + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): + skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + int n = scale_factor * (rand() % 64) + 24; + + cutlass::gemm::GemmCoord problem( + n, + n, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + FillModeC>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run(skip_tile_check)); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run(skip_tile_check)); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/testbed_grouped_scheduler.h b/test/unit/gemm/device/testbed_grouped_scheduler.h new file mode 100644 index 00000000..0bd409e0 --- /dev/null +++ b/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -0,0 +1,406 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped GEMM problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= 
this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + problem_visitor.advance(gridDim.x); + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. 
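To summarize what the baseline visitor and the test kernel above establish: tiles from all problems form one flat index space, each block starts at its `blockIdx.x` and strides by `gridDim.x`, and the runner later asserts that, across blocks, every (problem, tile) pair is visited exactly once. A host-side Python sketch of that behaviour (hypothetical names, dense GEMM case with no tile skipping):

```python
def simulate_visits(tiles_per_problem, num_blocks):
    """Round-robin assignment of flat tile ids to blocks, recording the
    (problem_idx, local_tile_idx) pairs each block would store."""
    starts, total = [], 0
    for t in tiles_per_problem:
        starts.append(total)
        total += t
    visits = [[] for _ in range(num_blocks)]
    for block in range(num_blocks):
        # problem_visitor.advance(gridDim.x) corresponds to striding by num_blocks.
        for tile_idx in range(block, total, num_blocks):
            problem = max(i for i, s in enumerate(starts) if s <= tile_idx)
            visits[block].append((problem, tile_idx - starts[problem]))
    return visits

# The completeness check mirrors ProblemVisitorRunner::verify(): sort all recorded
# visits and compare against tiles 0..n-1 of every problem.
tiles_per_problem = [3, 1, 5]
recorded = sorted(v for block in simulate_visits(tiles_per_problem, num_blocks=4) for v in block)
assert recorded == [(p, t) for p, n in enumerate(tiles_per_problem) for t in range(n)]
```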
+ cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + 
visited_tiles.copy_to_host(host_visited_tiles.data()); + + return verify(); + } +}; + +template +struct TestbedGroupedGemmScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + ThreadCount>; + + // + // Data members + // + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): + seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + Transpose>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run()); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run()); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 095fbb91..45786089 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -149,6 +149,35 @@ class GemmOperation: ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' return self.procedural_name() + +################################################################################################### +# +# Data structure modeling a grouped GEMM operation +# +################################################################################################### + +# +class GroupedGemmOperation(GemmOperation): + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + scheduler_mode = GroupScheduleMode.Device): + super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor, swizzling_functor) + + self.scheduler_mode = scheduler_mode + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + base = super().procedural_name() + return SubstituteTemplate( + base + "_schedule${schedule}", + { + 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode] + }) + + ################################################################################################### # # Emits single instances of a CUTLASS device-wide operator @@ -738,6 +767,7 @@ using ${operation_name}_base = ${epilogue_functor}, ${swizzling_functor}, ${stages}, + ${scheduler_mode}, ${math_operation} >::GemmKernel; @@ -817,6 +847,7 @@ ${compile_guard_end} 'align_b': str(operation.B.alignment), 'transform_a': ComplexTransformTag[operation.A.complex_transform], 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode], 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] } diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index bc03f2ef..f0d84699 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -180,7 +180,7 @@ def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, \ B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) C = TensorDescription(element_c, layout[2], alignment_c) - new_operation = GemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \ + new_operation = GroupedGemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \ tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) manifest.append(new_operation) @@ -346,7 +346,7 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme # iterator algorithm (analytic and optimized) #iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] iterator_algorithms = [IteratorAlgorithm.Optimized] - + # by default, only generate the largest tile size, largest alignment, and optimized iterator if manifest.kernel_filter == '': tile_descriptions = [tile_descriptions[0],] @@ -527,7 +527,7 @@ def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignme alignment_c = min(8, alignment) # iterator algorithm (analytic and optimized) - #iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] +# iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] iterator_algorithms = [IteratorAlgorithm.Optimized] # by default, only generate the largest tile size and optimized iterators @@ -1677,7 +1677,6 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 
80 alignment_constraints = [8, 4, 2] @@ -1694,12 +1693,14 @@ def GenerateSM80_TensorOp_16816(manifest, cuda_version): TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), @@ -1773,23 +1774,22 @@ def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [8] for math_inst in math_instructions: tile_descriptions = [ TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), ] @@ -1917,7 +1917,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 + smem_usage = 164 alignment_constraints = [16,] @@ -1931,10 +1931,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): TileDescription([128, 64, 64], 6, 
[2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), @@ -1986,22 +1986,21 @@ def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [16,] tile_descriptions = [ TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), ] @@ -2102,8 +2101,6 @@ def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 - alignment_constraints = [32,] for math_inst in math_instructions: @@ -2116,11 +2113,11 @@ def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 256], 3, [4, 2, 1], 
math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2173,21 +2170,19 @@ def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 - alignment_constraints = [32,] tile_descriptions = [ TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), ] @@ -2338,7 +2333,6 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [4, 2, 1] @@ -2354,11 +2348,11 @@ def GenerateSM80_TensorOp_1688(manifest, cuda_version): TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, 
max_cc_smem_limited), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2424,7 +2418,6 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [4, 2, 1] @@ -2440,11 +2433,11 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2483,7 +2476,6 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [4, 2, 1] @@ -2497,8 +2489,8 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): TileDescription([ 64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), @@ -2583,23 +2575,22 @@ def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [4] for 
math_inst in math_instructions: tile_descriptions = [ TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc_smem_limited), + TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), ] @@ -3047,7 +3038,6 @@ def GenerateSM80_TensorOp_884(manifest, cuda_version): min_cc = 80 max_cc = 1024 - max_cc_smem_limited = 80 alignment_constraints = [1,] diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index a6b78324..93c93ec6 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -456,7 +456,7 @@ OperationKindNames = { # class Target(enum.Enum): library = enum_auto() - +# ArchitectureNames = { 50: 'maxwell', 60: 'pascal', @@ -466,6 +466,16 @@ ArchitectureNames = { 80: 'ampere', } +# +SharedMemPerCC = { + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM + 80: 160, # 164KB of SMEM - 4KB reserved for the driver + 86: 100, # 100KB of SMEM + 87: 160, # 164KB of SMEM - 4KB reserved for the driver +} + ################################################################################################### # @@ -564,6 +574,23 @@ SwizzlingFunctorTag = { SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', } +# +class GroupScheduleMode(enum.Enum): + Device = enum_auto(), + Host = enum_auto() + +# +GroupScheduleModeTag = { + GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', + GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' +} + +# +ShortGroupScheduleModeNames = { + GroupScheduleMode.Device: 'Device', + GroupScheduleMode.Host: 'Host' +} + ################################################################################################### # @@ -636,7 +663,6 @@ class MathInstruction: self.opcode_class = opcode_class self.math_operation = math_operation - # class TileDescription: @@ -681,3 +707,29 @@ class TriangularTensorDescription: self.complex_transform = complex_transform ################################################################################################### + +# +def CalculateSmemUsage(operation): + cta_shape = 
operation.tile_description.threadblock_shape + stages = operation.tile_description.stages + + if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: + # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) + if DataTypeSize[operation.A.element] == 32: + elements_per_8b_md = 2 + elif DataTypeSize[operation.A.element] == 4: + elements_per_8b_md = 8 + else: + elements_per_8b_md = 4 + + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md + else: + # Few BLAS3 operations only have A tensor + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \ + DataTypeSize[operation.A.element] * cta_shape[1] * cta_shape[2] // 8 + + smem_usage = smem_per_stage * stages + return (smem_usage >> 10) +################################################################################################### diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 408ab1f5..d5cbf614 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -276,7 +276,8 @@ class Manifest: for cc in self.compute_capabilities: if cc >= operation.tile_description.minimum_compute_capability and \ - cc <= operation.tile_description.maximum_compute_capability: + cc <= operation.tile_description.maximum_compute_capability and \ + (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)): enabled = True break diff --git a/tools/library/scripts/pycutlass/README.md b/tools/library/scripts/pycutlass/README.md new file mode 100644 index 00000000..8d4f9279 --- /dev/null +++ b/tools/library/scripts/pycutlass/README.md @@ -0,0 +1,120 @@ +# PyCUTLASS: CUTLASS Python Interface + +PyCUTLASS is a python interface of CUTLASS C++ template library. PyCUTLASS takes user-defined operation descriptions, emits C++ code, and compiles it with `nvcc` or `nvrtc`. It also provides wrappers for user-provide arguments from [numpy](https://numpy.org/), [torch](https://pytorch.org/), and [cupy](https://github.com/cupy/cupy) and encode them to kernel's parameters. 
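Stepping back to the generator-script changes above: the new `SharedMemPerCC` table and `CalculateSmemUsage` replace the old per-tile `max_cc_smem_limited` caps, so a tile is now emitted for a compute capability only if its per-stage A and B tiles, multiplied by the stage count, fit in that architecture's shared memory. A quick standalone sanity check of that arithmetic (a sketch that mirrors, but does not import, the `library.py` logic; dense fp16 case only):

```python
# Shared memory budget per compute capability, in KB (values copied from SharedMemPerCC).
SHARED_MEM_PER_CC_KB = {70: 96, 72: 96, 75: 64, 80: 160, 86: 100, 87: 160}

def smem_usage_kb(cta_shape, stages, bits_a=16, bits_b=16):
    m, n, k = cta_shape
    # Bytes of the A tile plus the B tile staged in shared memory per pipeline stage.
    smem_per_stage = bits_a * m * k // 8 + bits_b * n * k // 8
    return (smem_per_stage * stages) >> 10

usage = smem_usage_kb([256, 128, 64], stages=3)   # 144 KB for an fp16 256x128x64 tile
print(usage <= SHARED_MEM_PER_CC_KB[80], usage <= SHARED_MEM_PER_CC_KB[86])
# -> True False: the tile is kept when targeting SM80 but filtered out for SM86.
```

This is why the SM80 tile lists above can drop `max_cc_smem_limited`: the manifest now performs the equivalent shared-memory check per compute capability.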
+ +```python +import pycutlass +from pycutlass import * +import torch + +pycutlass.get_memory_pool(2**8, 2**32) + +math_inst = MathInstruction( + [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, + cutlass.OpClass.Simt, MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 8], 4, [2, 4, 1], + math_inst, 80, 80 +) + +A = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 +) + +B = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 +) + +C = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 +) + +operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=cutlass.float32, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 +) + +pycutlass.compiler.add_module([operation,]) + +problem_size = cutlass.gemm.GemmCoord(512, 256, 128) + +tensor_A = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.k()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_B = torch.ceil(torch.empty(size=(problem_size.k(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_C = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_D = torch.empty_like(tensor_C) + + +alpha = 1.0 +beta = 0.0 + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=LinearCombinationFunctorArguments(alpha, beta), + gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 +) + +operation.run(arguments) + +arguments.sync() + +tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C + +assert torch.equal(tensor_D, tensor_D_ref) +``` +PyCUTLASS also provides infrastructures for profiling, compiled artifact management, and pool memory manager + +## Installation + +### Using Docker +You can run the PyCUTLASS on NGC pytorch container. +```shell +docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.08-py3 +``` +PyCUTLASS requires additional dependency Boost C++ library, which can be installed with +```bash +apt-get update +apt-get -y install libboost-all-dev +``` + + + +### Environment variables +PyCUTLASSS requires two environment variables: +* `CUTLASS_PATH`: the root directory of CUTLASS +* `CUDA_INSTALL_PATH`: the directory where cuda toolkit is installed + +After setting these two environment variables, PyCUTLASS can be installed with +```shell +cd $CUTLASS_PATH/tools/library/scripts/pycutlass && bash build.sh +``` + +## Examples +Examples can be found in `$CUTLASS_PATH/examples/40_cutlass_py` + +## Test +The test cases are listed in `$CUTLASS_PATH//tools/library/scripts/pycutlass/test`. The unit test can be run with +```shell +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py +``` + + +## Troubleshooting + +### Issue 1: permission denied +Building PyCUTLASS requires installing dependencies to python. So conda could an option if you don't have permission. + +### Issue 2: rmm: module not found +PyCUTLASS manages the device memory with [RMM](https://github.com/rapidsai/rmm). Our `build.sh` automatically pull the [rmm branch-22.08](https://github.com/rapidsai/rmm/tree/branch-22.08) from github and build it from source. The rmm is allocated at `$CUTLASS_PATH/tools/library/scripts/pycutlass/rmm`. It requires `cmake > 3.20.1`. 
If the build fails, it can be manually fixed with the following steps: +```shell +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm && ./build.sh librmm rmm + +cd $CUTLASS_PATH/tools/library/scripts/pycutlass/rmm/python +python setup.py build_ext --inplace +python setup.py install +``` +To test whether rmm is successfully installed, try `import rmm`. For other issues related to rmm, please check https://github.com/rapidsai/rmm/issues. diff --git a/tools/library/scripts/pycutlass/build.sh b/tools/library/scripts/pycutlass/build.sh new file mode 100644 index 00000000..cffc85a6 --- /dev/null +++ b/tools/library/scripts/pycutlass/build.sh @@ -0,0 +1,4 @@ +pip install pybind11 +git clone https://github.com/google/googletest.git +python setup.py install +python setup.py rmm diff --git a/tools/library/scripts/pycutlass/build_doc.sh b/tools/library/scripts/pycutlass/build_doc.sh new file mode 100644 index 00000000..def7c773 --- /dev/null +++ b/tools/library/scripts/pycutlass/build_doc.sh @@ -0,0 +1,2 @@ +python setup.py develop +sphinx-build -b html docs/source/ docs/build/html diff --git a/tools/library/scripts/pycutlass/docs/Makefile b/tools/library/scripts/pycutlass/docs/Makefile new file mode 100644 index 00000000..94cd4eb0 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/Makefile @@ -0,0 +1,52 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". 
+help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/tools/library/scripts/pycutlass/docs/make.bat b/tools/library/scripts/pycutlass/docs/make.bat new file mode 100644 index 00000000..061f32f9 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/tools/library/scripts/pycutlass/docs/source/conf.py b/tools/library/scripts/pycutlass/docs/source/conf.py new file mode 100644 index 00000000..73ec0687 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/conf.py @@ -0,0 +1,93 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# Configuration file for the Sphinx documentation builder. 
+# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'PyCutlass' +copyright = '2022, Andrew Kerr; Zhaodong Chen; Haicheng Wu; Szymon Migacz; Graham Markall' +author = 'Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.duration', + 'sphinx.ext.doctest', + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'enum_tools.autoenum', + 'sphinx.ext.autosummary' +] + +autosummary_generate = True +autosummary_imported_members = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'classic' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] diff --git a/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst b/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst new file mode 100644 index 00000000..7ce0510d --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/conv2d_op.rst @@ -0,0 +1,13 @@ +CONV2D Operation +================ + +.. autoclass:: pycutlass.Conv2dOperation + :special-members: + :members: run + :exclude-members: __weakref__, configuration_name, core_name, extended_name, procedural_name + +.. autoclass:: pycutlass.Conv2dArguments + :special-members: + :members: + :exclude-members: initialize + :show-inheritance: diff --git a/tools/library/scripts/pycutlass/docs/source/cutlass.rst b/tools/library/scripts/pycutlass/docs/source/cutlass.rst new file mode 100644 index 00000000..6ec68253 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/cutlass.rst @@ -0,0 +1,2 @@ +cutlass +======= diff --git a/tools/library/scripts/pycutlass/docs/source/descriptor.rst b/tools/library/scripts/pycutlass/docs/source/descriptor.rst new file mode 100644 index 00000000..cd0a3b98 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/descriptor.rst @@ -0,0 +1,6 @@ +Descriptions +============== + +.. 
autoclass:: pycutlass.TileDescription + :special-members: + :members: diff --git a/tools/library/scripts/pycutlass/docs/source/frontend.rst b/tools/library/scripts/pycutlass/docs/source/frontend.rst new file mode 100644 index 00000000..1da97eeb --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/frontend.rst @@ -0,0 +1,5 @@ +Frontend +============== + +.. autoclass:: pycutlass.NumpyFrontend + :members: diff --git a/tools/library/scripts/pycutlass/docs/source/gemm_op.rst b/tools/library/scripts/pycutlass/docs/source/gemm_op.rst new file mode 100644 index 00000000..e4bcd8b4 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/gemm_op.rst @@ -0,0 +1,18 @@ +GEMM Operation +============== + +.. autoclass:: pycutlass.GemmOperationUniversal + :special-members: + :members: + +.. autoclass:: pycutlass.GemmOperationGrouped + :special-members: + :members: + +.. autoclass:: pycutlass.GemmArguments + :special-members: + :members: + +.. autoclass:: pycutlass.GemmGroupedArguments + :special-members: + :members: diff --git a/tools/library/scripts/pycutlass/docs/source/index.rst b/tools/library/scripts/pycutlass/docs/source/index.rst new file mode 100644 index 00000000..5e2fa7ad --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/index.rst @@ -0,0 +1,29 @@ +.. PyCutlass documentation master file, created by + sphinx-quickstart on Sun Jun 19 12:05:42 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to PyCutlass's documentation! +===================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + + +.. toctree:: + types + cutlass + descriptor + frontend + gemm_op + conv2d_op diff --git a/tools/library/scripts/pycutlass/docs/source/types.rst b/tools/library/scripts/pycutlass/docs/source/types.rst new file mode 100644 index 00000000..15893511 --- /dev/null +++ b/tools/library/scripts/pycutlass/docs/source/types.rst @@ -0,0 +1,6 @@ +Types +======== + + +.. autoenum:: pycutlass.OperationKind + :members: diff --git a/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py b/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py new file mode 100644 index 00000000..16e04cca --- /dev/null +++ b/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py @@ -0,0 +1,104 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from pycutlass import * +import pycutlass +from pycutlass.test.conv2d_testbed import Conv2dLauncher + + +if __name__ == "__main__": + pycutlass.get_memory_pool(2**33, 2**33) + pycutlass.compiler.nvcc() + + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=4, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + profiler = Conv2dLauncher(operation, verification=False, profiling=True) + + python_runtime = profiler.run( + problem_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(32, 224, 224, 128), + cutlass.Tensor4DCoord(128, 3, 3, 128), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), split_k_mode=cutlass.conv.SplitKMode.Serial + ) + + + cpp_runtime = profiler.run_cutlass_profiler( + problem_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(32, 224, 224, 128), + cutlass.Tensor4DCoord(128, 3, 3, 128), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), split_k_mode=cutlass.conv.SplitKMode.Serial + ) + + print(cpp_runtime / python_runtime) diff --git a/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py new file mode 100644 index 00000000..31f52546 --- /dev/null +++ b/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py @@ -0,0 +1,91 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pycutlass +from pycutlass import * +from pycutlass.test import * +from pycutlass.test.gemm_testbed import GemmUniversalLauncher + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**32, 2**32) + pycutlass.compiler.nvcc() + + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[256, 128, 32], + stages=3, warp_count=[4, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=4 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=4 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + profiler = GemmUniversalLauncher(operation, verification=False, profiling=True) + python_runtime = profiler.run( + mode=cutlass.gemm.Mode.Gemm, + problem_size=cutlass.gemm.GemmCoord(4096, 4096, 4096) + ) + + cpp_runtime = profiler.run_cutlass_profiler( + mode=cutlass.gemm.Mode.Gemm, + problem_size=cutlass.gemm.GemmCoord(4096, 4096, 4096), + ) + + print(cpp_runtime / python_runtime) diff --git a/tools/library/scripts/pycutlass/pyproject.toml b/tools/library/scripts/pycutlass/pyproject.toml new file mode 100644 index 00000000..e192f102 --- /dev/null +++ 
b/tools/library/scripts/pycutlass/pyproject.toml @@ -0,0 +1,9 @@ +[build-system] + +requires = [ + "setuptools", + "scikit-build>0.13.1", + "pybind11", + "numpy<1.23", + "cmake>=3.20.1,!=3.23.0" +] diff --git a/tools/library/scripts/pycutlass/setup.py b/tools/library/scripts/pycutlass/setup.py new file mode 100644 index 00000000..c3933455 --- /dev/null +++ b/tools/library/scripts/pycutlass/setup.py @@ -0,0 +1,79 @@ +import distutils.cmd +from setuptools import setup +import setuptools.command.build_py +import os + +# build rmm dependency +class BuildRMM(distutils.cmd.Command): + user_options = [] + def initialize_options(self): + pass + def finalize_options(self): + pass + def run(self): + try: + import rmm + except ImportError: + print("installing rmm") + os.system("git clone -b branch-22.08 --recurse-submodules https://github.com/rapidsai/rmm.git") + os.chdir("./rmm") + os.system("./build.sh librmm rmm") + os.chdir("./python") + os.system("python setup.py build_ext --inplace") + os.system("python setup.py install") + +cutlass_path = os.getenv('CUTLASS_PATH') +assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." +cuda_install_path = os.getenv('CUDA_INSTALL_PATH') +assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." + +ext_modules = [] + +try: + from pybind11.setup_helpers import Pybind11Extension, build_ext + include_dirs = [ + cutlass_path + "/include", + cuda_install_path + "/include", + cutlass_path + "/tools/util/include", + cutlass_path + "/test", + cutlass_path + "/tools/library/scripts/pycutlass/googletest/googletest/include" + ] + + ext_modules = [ + Pybind11Extension("cutlass", + ["src/cpp/cutlass.cpp"], + include_dirs=include_dirs, + extra_compile_args=["-fpermissive"]) + ] +except ImportError: + pass + +setup( + name="PyCutlass", + version="0.0.1", + author="Zhaodong Chen; Andrew Kerr; Haicheng Wu; Szymon Migacz; Graham Markall", + author_email="zhaodongc@nvidia.com", + description="Python interface for CUTLASS", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + package_dir={"": "src"}, + packages=['pycutlass', 'pycutlass.utils', 'pycutlass.test'], + setup_requires=["pybind11", "numpy<1.23"], + install_requires=[ + "numpy<1.23", + 'pybind11', + 'cuda-python<11.7.0', + 'typeguard', + 'bfloat16', + 'typing', + 'scikit-build' + ], + cmdclass={ + 'rmm': BuildRMM + }, + ext_modules=ext_modules, + python_requires=">=3.6", +) diff --git a/tools/library/scripts/pycutlass/src/cpp/compiler.h b/tools/library/scripts/pycutlass/src/cpp/compiler.h new file mode 100644 index 00000000..220f003a --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/compiler.h @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
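The `setup.py` above hard-asserts that `CUTLASS_PATH` and `CUDA_INSTALL_PATH` are defined. A small pre-flight check (a sketch, not part of the patch) can fail earlier with a friendlier message:

```python
import os

# setup.py asserts on these variables; checking them up front gives a clearer error.
for var in ("CUTLASS_PATH", "CUDA_INSTALL_PATH"):
    if os.getenv(var) is None:
        raise SystemExit(f"Set {var} before building PyCutlass (see setup.py)")
```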
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief In-memory compiled artifact cache +*/ + +#include +#include +#include + + +namespace py = pybind11; + +namespace cutlass { + +struct CompileCache { +public: + CompileCache() = default; + ~CompileCache() = default; + + using Cache = std::unordered_map; + + /// Check if the kernel has already been compiled + py::object at(const std::string &kernel) { + auto item = cache_.find(kernel); + + if (item != cache_.end()) { + return item->second; + } + return py::none(); + } + + /// Insert a new compiled kernel for new configuration + void insert(const std::string &kernel, const py::object &compiled_kernel){ + cache_.emplace(kernel, compiled_kernel); + } + + const int64_t size() const { return cache_.size(); } + + /// Clear the cache + void clear() { cache_.clear(); } + +private: + Cache cache_; +}; + +} // namespace cutlass diff --git a/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp b/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp new file mode 100644 index 00000000..85d39238 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/cutlass.cpp @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief binding cutlass C++ APIs to python +*/ +#include +#include + +#include "builtin_types.h" +#include "device_launch_parameters.h" +#include "stddef.h" +#include "cutlass/cutlass.h" + +#include "include/conv/convolution.h" +#include "include/gemm/gemm.h" +#include "include/types.h" +#include "include/layout/layout.h" +#include "include/tensor_coord.h" +#include "include/arch.h" +#include "include/tensor_ref_view.h" +#include "include/swizzling.h" +#include "test/conv/convolution.h" +#include "test/gemm/gemm.h" + + +// Data Types +#include "library.h" + +// compiler +#include "compiler.h" + + +namespace py = pybind11; + + +PYBIND11_MODULE(cutlass, m) { + + // module doc + m.doc() = "cutlass C++ binding"; + + // + // Bind data type + // + bind_cutlass_types(m); + + // + // Bind layout + // + bind_layout(m); + + // + // Bind tensor coord + // + bind_tensor_coord(m); + + // + // Bind tensor ref + // + bind_tensor_refs_and_views(m); + + // + // Bind opcode + // + bind_opcode(m); + + // + // Bind convolution + // + py::module_ conv_submodule = m.def_submodule("conv"); + bind_convolution(conv_submodule); + + // + // Bind gemm + // + py::module_ gemm_submodule = m.def_submodule("gemm"); + bind_gemm(gemm_submodule); + + // + // Bind swizzling + // + bind_threadblock_swizzle(m); + + + // + // Bind test units + // + py::module_ test = m.def_submodule("test"); + py::module_ test_conv = test.def_submodule("conv"); + bind_convolution_test(test_conv); + + py::module_ test_gemm = test.def_submodule("gemm"); + bind_gemm_test(test_gemm); + + // data types + py::enum_(m, "dtype") + .value("b1", cutlass::DataType::kB1) + .value("u2", cutlass::DataType::kU2) + .value("u4", cutlass::DataType::kU4) + .value("u8", cutlass::DataType::kU8) + .value("u16", cutlass::DataType::kU16) + .value("u32", cutlass::DataType::kU32) + .value("u64", cutlass::DataType::kU64) + .value("s2", cutlass::DataType::kS2) + .value("s4", cutlass::DataType::kS4) + .value("s16", cutlass::DataType::kS16) + .value("s64", cutlass::DataType::kS64) + .value("cf16", cutlass::DataType::kCF16) + .value("cbf16", cutlass::DataType::kCBF16) + .value("cf32", cutlass::DataType::kCF32) + .value("ctf32", cutlass::DataType::kCTF32) + .value("cf64", cutlass::DataType::kCF64) + .value("cs2", cutlass::DataType::kCS2) + .value("cs4", cutlass::DataType::kCS4) + .value("cs8", cutlass::DataType::kCS8) + .value("cs16", cutlass::DataType::kCS16) + .value("cs32", cutlass::DataType::kCS32) + .value("cs64", cutlass::DataType::kCS64) + .value("cu2", cutlass::DataType::kCU2) + .value("cu4", cutlass::DataType::kCU4) + .value("cu8", cutlass::DataType::kCU8) + .value("cu16", cutlass::DataType::kCU16) + .value("cu32", cutlass::DataType::kCU32) + .value("cu64", cutlass::DataType::kCU64) + .value("invalid", cutlass::DataType::kInvalid); + + // layout types + py::enum_(m, "layout") + .value("ColumnMajorInterleaved2", 
cutlass::LayoutType::kColumnMajorInterleaved2) + .value("RowMajorInterleaved2", cutlass::LayoutType::kRowMajorInterleaved2) + .value("ColumnMajorInterleaved64", cutlass::LayoutType::kColumnMajorInterleaved64) + .value("RowMajorInterleaved64", cutlass::LayoutType::kRowMajorInterleaved64) + .value("TensorNDHWC", cutlass::LayoutType::kTensorNDHWC) + .value("TensorNCHW", cutlass::LayoutType::kTensorNCHW) + .value("TensorNGHWC", cutlass::LayoutType::kTensorNGHWC) + .value("TensorNC64HW64", cutlass::LayoutType::kTensorNC64HW64) + .value("TensorC64RSK64", cutlass::LayoutType::kTensorC64RSK64); + + // transform types + py::enum_(m, "complex_transform") + .value("none", cutlass::ComplexTransform::kNone) + .value("conj", cutlass::ComplexTransform::kConjugate); + + // + // Compiler + // + py::class_(m, "CompileCache") + .def(py::init<>()) + .def("at", &cutlass::CompileCache::at) + .def("insert", &cutlass::CompileCache::insert) + .def("size", &cutlass::CompileCache::size) + .def("clear", &cutlass::CompileCache::clear); + +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/arch.h b/tools/library/scripts/pycutlass/src/cpp/include/arch.h new file mode 100644 index 00000000..02776d4d --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/arch.h @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
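The `CompileCache` bound at the end of `cutlass.cpp` above is a plain key/object store. Roughly, it behaves like this from Python; this is a sketch, and the key string and cached object below are hypothetical placeholders.

```python
import cutlass  # the pybind11 module built from src/cpp/cutlass.cpp

cache = cutlass.CompileCache()
key = "gemm_f16_sm80_128x128x32"      # hypothetical key; callers choose the keying scheme
if cache.at(key) is None:             # at() returns None when nothing is cached for the key
    cache.insert(key, b"compiled-artifact-bytes")  # any Python object can be stored
print(cache.size())                   # -> 1
cache.clear()
```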
+ * + **************************************************************************************************/ +/* \file + \brief Bind opcode classes to python +*/ +#pragma once +#include +#include + +#include "cutlass/arch/mma.h" + +namespace py = pybind11; + +namespace cutlass { +enum class OpcodeClass { + kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp +}; +} + +void bind_opcode(py::module &m) { + py::enum_(m, "OpClass", + R"pbdoc(classification of math operators)pbdoc") + .value("Simt", cutlass::OpcodeClass::kSimt, + R"pbdoc(Tag classifying math operators as thread-level operations)pbdoc") + .value("TensorOp", cutlass::OpcodeClass::kTensorOp, + R"pbdoc(Tag classifying operators as Tensor Core operations)pbdoc") + .value("WmmaTensorOp", cutlass::OpcodeClass::kWmmaTensorOp, + R"pbdoc(Tag classifying operators as WMMA Tensor Core operations)pbdoc") + .value("SparseTensorOp", cutlass::OpcodeClass::kSparseTensorOp, + R"pbdoc(Tag classifying operators as sparse Tensor Core operations)pbdoc"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h b/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h new file mode 100644 index 00000000..700f7ea8 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/conv/conv_problem_size.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Bind Convolution problem sizes to python +*/ +#pragma once +#include +#include + +#include "cutlass/conv/conv2d_problem_size.h" + +namespace py = pybind11; + +void bind_conv_problem_size(py::module &m) { + // + // Conv2d Problem Size: + // include/cutlass/conv/conv2d_problem_size.h + // + py::class_(m, "Conv2dProblemSize") + // constructors + .def(py::init()) + .def(py::init()) + // attribute accessors + .def_readwrite("N", &cutlass::conv::Conv2dProblemSize::N) + .def_readwrite("H", &cutlass::conv::Conv2dProblemSize::H) + .def_readwrite("W", &cutlass::conv::Conv2dProblemSize::W) + .def_readwrite("C", &cutlass::conv::Conv2dProblemSize::C) + .def_readwrite("P", &cutlass::conv::Conv2dProblemSize::P) + .def_readwrite("Q", &cutlass::conv::Conv2dProblemSize::Q) + .def_readwrite("K", &cutlass::conv::Conv2dProblemSize::K) + .def_readwrite("R", &cutlass::conv::Conv2dProblemSize::R) + .def_readwrite("S", &cutlass::conv::Conv2dProblemSize::S) + .def_readwrite("pad_h", &cutlass::conv::Conv2dProblemSize::pad_h) + .def_readwrite("pad_w", &cutlass::conv::Conv2dProblemSize::pad_w) + .def_readwrite("stride_h", &cutlass::conv::Conv2dProblemSize::stride_h) + .def_readwrite("stride_w", &cutlass::conv::Conv2dProblemSize::stride_w) + .def_readwrite("dilation_h", &cutlass::conv::Conv2dProblemSize::dilation_h) + .def_readwrite("dilation_w", &cutlass::conv::Conv2dProblemSize::dilation_w) + .def_readwrite("mode", &cutlass::conv::Conv2dProblemSize::mode) + .def_readwrite("split_k_slices", &cutlass::conv::Conv2dProblemSize::split_k_slices) + .def_readwrite("groups", &cutlass::conv::Conv2dProblemSize::groups) + // functions + .def("reset_split_k_slices", &cutlass::conv::Conv2dProblemSize::reset_split_k_slices) + .def("activation_extent", &cutlass::conv::Conv2dProblemSize::activation_extent) + .def("filter_extent", &cutlass::conv::Conv2dProblemSize::filter_extent) + .def("output_extent", &cutlass::conv::Conv2dProblemSize::output_extent) + .def("activation_size", &cutlass::conv::Conv2dProblemSize::activation_size) + .def("filter_size", &cutlass::conv::Conv2dProblemSize::filter_size) + .def("output_size", &cutlass::conv::Conv2dProblemSize::output_size); + + // Get tensor size + m.def("implicit_gemm_tensor_a_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_a_size)); + m.def("implicit_gemm_tensor_b_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_b_size)); + m.def("implicit_gemm_tensor_c_size", py::overload_cast(&cutlass::conv::implicit_gemm_tensor_c_size)); + + // Get tensor extent + m.def("implicit_gemm_tensor_a_extent", + py::overload_cast< + cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& + >(&cutlass::conv::implicit_gemm_tensor_a_extent)); + + m.def("implicit_gemm_tensor_b_extent", + py::overload_cast< + cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& + >(&cutlass::conv::implicit_gemm_tensor_b_extent)); + + m.def("implicit_gemm_tensor_c_extent", + py::overload_cast< + cutlass::conv::Operator, const cutlass::conv::Conv2dProblemSize& + >(&cutlass::conv::implicit_gemm_tensor_c_extent)); + + m.def("implicit_gemm_problem_size", py::overload_cast(&cutlass::conv::implicit_gemm_problem_size)); + +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h b/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h new file mode 100644 index 00000000..53cb6272 --- /dev/null +++ 
b/tools/library/scripts/pycutlass/src/cpp/include/conv/convolution.h @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
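The `convolution.h` hunk that follows binds the `cutlass.conv` enums (`Operator`, `Mode`, `IteratorAlgorithm`, `StrideSupport`, `SplitKMode`); together with the `Conv2dProblemSize` binding above, describing a 2-D convolution from Python looks roughly like the conv2d profiling script earlier in this patch. A compact sketch:

```python
import cutlass

# Hypothetical fprop problem (mirrors conv2d_f16_sm80.py above): NHWC activation
# 32x224x224x128, 128 filters of 3x3x128, padding 1, stride 1, dilation 1.
problem = cutlass.conv.Conv2dProblemSize(
    cutlass.Tensor4DCoord(32, 224, 224, 128),   # activation (N, H, W, C)
    cutlass.Tensor4DCoord(128, 3, 3, 128),      # filter (K, R, S, C)
    cutlass.Tensor4DCoord(1, 1, 1, 1),          # padding
    cutlass.MatrixCoord(1, 1),                  # stride
    cutlass.MatrixCoord(1, 1),                  # dilation
    cutlass.conv.Mode.cross_correlation,
    1, 1,                                       # split_k_slices, groups
)
print(problem.output_extent())                  # extent of the output tensor (N, P, Q, K)
```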
+ * + **************************************************************************************************/ +/* \file + \brief Bind convolution related enum types to python +*/ +#pragma once +#include +#include + +#include "conv_problem_size.h" +#include "host.h" +#include "cutlass/conv/convolution.h" + +namespace py = pybind11; + +void bind_convolution(py::module &m) { + // + // Enumerate types + // cutlass/include/cutlass/conv/convolution.h + // + + /// Convolutional operator + py::enum_(m, "Operator", R"pbdoc(Convolutional operator)pbdoc") + .value("fprop", cutlass::conv::Operator::kFprop, "Forward propagation") + .value("dgrad", cutlass::conv::Operator::kDgrad, "Activation grad") + .value("wgrad", cutlass::conv::Operator::kWgrad, "Weight grad"); + + /// Distinguishes convolution from cross correlation + py::enum_(m, "Mode") + .value("cross_correlation", cutlass::conv::Mode::kCrossCorrelation) + .value("convolution", cutlass::conv::Mode::kConvolution); + + /// Selects among several implementation variants trading off performance with simplicity + py::enum_(m, "IteratorAlgorithm", + R"pbdoc(Selects among several implementation variants trading off performance with simplicity)pbdoc") + .value("analytic", cutlass::conv::IteratorAlgorithm::kAnalytic, R"pbdoc(functionally correct in all cases but lower performance)pbdoc") + .value("optimized", cutlass::conv::IteratorAlgorithm::kOptimized, R"pbdoc(optimized for R <= 32, S <= 32 and unity-stride dgrad)pbdoc") + .value("fixed_channels", cutlass::conv::IteratorAlgorithm::kFixedChannels, R"pbdoc(Analytic algorithm optimized for fixed channel count (C == AccessSize))pbdoc") + .value("few_channels", cutlass::conv::IteratorAlgorithm::kFewChannels, R"pbdoc(Analytic algorithm optimized for few channels (C divisible by AccessSize))pbdoc"); + + /// Distinguishes among partial specializations that accelerate certain problems where convolution + /// stride is unit. + py::enum_(m, "StrideSupport", + R"pbdoc(Distinguishes among partial specializations that accelerate certain problems where convolution + stride is unit.)pbdoc") + .value("strided", cutlass::conv::StrideSupport::kStrided, R"pbdoc(arbitrary convolution stride)pbdoc") + .value("unity", cutlass::conv::StrideSupport::kUnity, R"pbdoc(unit convolution stride)pbdoc"); + + /// Identifies split-K mode + py::enum_(m, "SplitKMode") + .value("None", cutlass::conv::SplitKMode::kNone) + .value("Serial", cutlass::conv::SplitKMode::kSerial) + .value("Parallel", cutlass::conv::SplitKMode::kParallel); + + // Conv problem sizes + bind_conv_problem_size(m); + + // + // host helper functions + // + py::module_ host_submodule = m.def_submodule("host"); + bind_conv_host_helper(host_submodule); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h b/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h new file mode 100644 index 00000000..ad4808f5 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/conv/host.h @@ -0,0 +1,54 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind conv host helpers to python +*/ +#pragma once +#include +#include + +#include "cutlass/util/host_reorder.h" +#include "cutlass/layout/tensor.h" + +namespace py = pybind11; + + +void bind_conv_host_helper(py::module &m) { + + /// reorder operand B for interleaved layout + m.def("reorder_convK", []( + cutlass::TensorRef> dest, + cutlass::TensorRef> src, + cutlass::conv::Operator conv_op, const cutlass::conv::Conv2dProblemSize & problem_size) { + cutlass::gemm::GemmCoord implicit_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_op, problem_size); + cutlass::reorder_convK<32>(dest, src, implicit_problem_size); + }); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h new file mode 100644 index 00000000..d9a93475 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/gemm/gemm.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind gemm related enum types to python +*/ +#pragma once +#include +#include + +#include "cutlass/gemm/gemm.h" +#include "host.h" + +namespace py = pybind11; + +void bind_gemm(py::module &m) { + // + // Enumerate types + // cutlass/gemm/gemm.h + + py::enum_(m, "Mode") + .value("Gemm", cutlass::gemm::GemmUniversalMode::kGemm, "Ordinary GEMM & GEMM Split-K serial") + .value("GemmSplitKParallel", cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, "GEMM Split-K parallel") + .value("Batched", cutlass::gemm::GemmUniversalMode::kBatched, "Batched GEMM") + .value("Array", cutlass::gemm::GemmUniversalMode::kArray) + .value("Invalid", cutlass::gemm::GemmUniversalMode::kInvalid); + + /// GemmCoord is a structure that specifies a location within the coordinate space of a GEMM problem + py::class_(m, "GemmCoord") + .def(py::init()) + .def("m", py::overload_cast<>(&cutlass::gemm::GemmCoord::m)) + .def("n", py::overload_cast<>(&cutlass::gemm::GemmCoord::n)) + .def("k", py::overload_cast<>(&cutlass::gemm::GemmCoord::k)) + // get tensor coords + .def("mk", + [](const cutlass::gemm::GemmCoord & problem_size) { + return cutlass::MatrixCoord(problem_size.mk()); + }) + .def("kn", + [](const cutlass::gemm::GemmCoord & problem_size) { + return cutlass::MatrixCoord(problem_size.kn()); + }) + .def("mn", + [](const cutlass::gemm::GemmCoord & problem_size) { + return cutlass::MatrixCoord(problem_size.mn()); + }); + + py::module_ host_submodule = m.def_submodule("host"); + bind_gemm_host_helper(host_submodule); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h b/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h new file mode 100644 index 00000000..c12d4a91 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/gemm/host.h @@ -0,0 +1,47 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
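On the Python side, the `GemmCoord` and `Mode` bindings above are used the same way as in the gemm profiling script earlier in this patch. A compact sketch:

```python
import cutlass

problem = cutlass.gemm.GemmCoord(4096, 4096, 4096)   # (M, N, K), as in the profiler above
print(problem.m(), problem.n(), problem.k())         # 4096 4096 4096
print(problem.mk())                                  # MatrixCoord extent of operand A (M x K)
mode = cutlass.gemm.Mode.Gemm                        # ordinary GEMM / serial split-K
```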
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind gemm host helpers to python +*/ +#pragma once +#include +#include + +#include "cutlass/util/host_reorder.h" +#include "cutlass/layout/tensor.h" + +namespace py = pybind11; + + +void bind_gemm_host_helper(py::module &m) { + m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::RowMajorInterleaved<32>>); + m.def("reorder_column", &cutlass::reorder_column<32, int8_t, cutlass::layout::ColumnMajorInterleaved<32>>); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h b/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h new file mode 100644 index 00000000..070d1b3e --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/layout/layout.h @@ -0,0 +1,47 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Bind CUTLASS layouts to python +*/ +#pragma once +#include +#include + +#include "tensor.h" +#include "matrix.h" + + +namespace py = pybind11; + +void bind_layout(py::module &m) { + bind_tensor_layout(m); + bind_matrix_layout(m); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h b/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h new file mode 100644 index 00000000..c55fd4fd --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/layout/matrix.h @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind Matrix layouts to python +*/ +#pragma once +#include +#include + +#include "cutlass/layout/matrix.h" + +namespace py = pybind11; + +void bind_matrix_layout(py::module &m) { + // + // Matrix layouts + // cutlass/layout/matrix.h + // + + py::class_(m, "RowMajor", R"pbdoc( + Mapping function for row-major matrices. + )pbdoc") + .def_static("packed", &cutlass::layout::RowMajor::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") + .def("stride", [](const cutlass::layout::RowMajor & layout){ + return layout.stride().at(0); + }, R"pbdoc(Returns the stride of the layout)pbdoc"); + + py::class_(m, "ColumnMajor", R"pbdoc( + Mapping function for column-major matrices. 
+ )pbdoc") + .def_static("packed", &cutlass::layout::ColumnMajor::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc" ) + .def("stride", [](const cutlass::layout::ColumnMajor & layout){ + return layout.stride().at(0); + }, R"pbdoc(Returns the stride of the layout)pbdoc"); + + py::class_>(m, "RowMajorInterleaved32", + R"pbdoc(Mapping function for interleaved matrices. Matrix is structured + as row-major arrangement of fixed-size columns 32)pbdoc") + .def_static("packed", &cutlass::layout::RowMajorInterleaved<32>::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") + .def("stride", [](const cutlass::layout::RowMajorInterleaved<32> & layout){ + return layout.stride().at(0); + }, R"pbdoc(Returns the stride of the layout)pbdoc"); + + py::class_>(m, "ColumnMajorInterleaved32", + R"pbdoc(Mapping function for interleaved matrices. Matrix is structured + as column-major arrangement of fixed-size rows 32)pbdoc") + .def_static("packed", &cutlass::layout::ColumnMajorInterleaved<32>::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") + .def("stride", [](const cutlass::layout::ColumnMajorInterleaved<32> & layout){ + return layout.stride().at(0); + }, R"pbdoc(Returns the stride of the layout)pbdoc"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h b/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h new file mode 100644 index 00000000..1058a14c --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/layout/tensor.h @@ -0,0 +1,74 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
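The matrix-layout bindings above expose `packed()` and `stride()`; a minimal sketch of how they read from Python (the extent below is an arbitrary example):

```python
import cutlass

extent = cutlass.MatrixCoord(128, 64)          # 128 rows x 64 columns (arbitrary example)
row_major = cutlass.RowMajor.packed(extent)
col_major = cutlass.ColumnMajor.packed(extent)
print(row_major.stride())   # 64  -- leading dimension of a tightly packed row-major matrix
print(col_major.stride())   # 128 -- leading dimension of a tightly packed column-major matrix
```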
+ * + **************************************************************************************************/ +/* \file + \brief Bind Tensor layouts to python +*/ +#pragma once +#include +#include + +#include "cutlass/layout/tensor.h" + +namespace py = pybind11; + +void bind_tensor_layout(py::module &m) { + // + // Tensor layouts + // cutlass/include/cutlass/layout/tensor.h + // + + /// Mapping function for 4-D NHWC tensors. + py::class_(m, "TensorNHWC", + R"pbdoc(Mapping function for 4-D NHWC tensors)pbdoc") + .def_static("packed", &cutlass::layout::TensorNHWC::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed NHWC tensor)pbdoc") + .def("stride", py::overload_cast<>(&cutlass::layout::TensorNHWC::stride), + R"pbdoc(Returns the stride of the layout)pbdoc"); + + /// Mapping function for 4-D NC/xHWx tensors. + py::class_>(m, "TensorNC32HW32", + R"pbdoc(Mapping function for 4-D NC/32HW32 tensors)pbdoc") + .def_static("packed", &cutlass::layout::TensorNCxHWx<32>::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") + .def("stride", py::overload_cast<>(&cutlass::layout::TensorNCxHWx<32>::stride), + R"pbdoc(Returns the stride of the layout)pbdoc"); + + /// Mapping function for 4-D CxRSKx tensors. + py::class_>(m, "TensorC32RSK32", + R"pbdoc(Mapping function for 4-D C32RSK32 tensors)pbdoc") + .def_static("packed", &cutlass::layout::TensorCxRSKx<32>::packed, + py::arg("extent"), + R"pbdoc(Helper returns a layout to a tightly packed tensor)pbdoc") + .def("stride", py::overload_cast<>(&cutlass::layout::TensorCxRSKx<32>::stride), + R"pbdoc(Returns the stride of the layout)pbdoc"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h new file mode 100644 index 00000000..f6ff7565 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/swizzling.h @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind threadblock swizzling to python +*/ +#pragma once +#include +#include + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/conv/threadblock/threadblock_swizzle.h" + +#include +#include + +namespace py = pybind11; + +template +void bind_identity_swizzle(py::module & m, std::string name) { + py::class_(m, name.c_str(), + R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc") + .def(py::init<>()) + .def("get_tiled_shape", + py::overload_cast( + &T::get_tiled_shape, py::const_ + ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), + R"pbdoc(Returns the shape of the problem in units of logical tiles + + :param problem_size: gemm(M, N, K) + :type problem_size: :class:`cutlass.gemm.GemmCoord` + )pbdoc") + .def("get_tiled_shape", + py::overload_cast( + &T::get_tiled_shape, py::const_ + ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), + R"pbdoc(Returns the shape of the problem in units of logical tiles + + :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) + :type problem_size: :class:`cutlass.gemm.GemmCoord`) + )pbdoc") + .def("get_tiled_shape", + py::overload_cast( + &T::get_tiled_shape, py::const_ + ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), + R"pbdoc(Returns the shape of the problem in units of logical tiles + + :param problem_size: Implicit gemm problem size conv_operator(NZPQK, NDHWC, KTRSC) + :type problem_size: :class:`cutlass.gemm.GemmCoord`) + )pbdoc") + // TODO: the returned dim3 is not usable in python + .def("get_grid_shape", &T::get_grid_shape, + py::arg("tiled_shape"), + R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") + .def("tag", [](const T & swizzle){ + return boost::core::demangle(typeid(T).name()); + }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); +} + +template +void bind_swizzle(py::module & m, std::string name, std::string doc) { + py::class_(m, name.c_str(), doc.c_str()) + .def(py::init<>()) + .def("get_tiled_shape", + py::overload_cast( + &T::get_tiled_shape, py::const_ + ), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), + R"pbdoc(Returns the shape of the problem in units of logical tiles + + :param problem_size: gemm(M, N, K) + :type problem_size: :class:`cutlass.gemm.GemmCoord` + )pbdoc") + .def("get_grid_shape", &T::get_grid_shape, + py::arg("tiled_shape"), + R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") + .def("tag", [](const T & swizzle){ + return boost::core::demangle(typeid(T).name()); + }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); +} + +template +void bind_dgrad_swizzle(py::module & m, std::string name) { + py::class_(m, name.c_str(), + R"pbdoc(Threadblock 
swizzling function for strided dgrad convolution)pbdoc") + .def(py::init<>()) + .def("get_tiled_shape", + py::overload_cast( + &T::get_tiled_shape, py::const_ + ), py::arg("conv_operator"), py::arg("problem_size"), py::arg("tile_size"), py::arg("split_k_slices"), + R"pbdoc(Returns the shape of the problem in units of logical tiles + + :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) + :type problem_size: :class:`cutlass.gemm.GemmCoord`) + )pbdoc") + .def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) { + return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); + }, py::arg("tiled_shape"), + R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") + .def("tag", [](const T & swizzle){ + return boost::core::demangle(typeid(T).name()); + }, R"pbdoc(Returns the c++ name of the swizzling for code emittion)pbdoc"); +} + +void bind_threadblock_swizzle(py::module &m) { + + py::class_(m, "dim3", + R"pbdoc(A int3 type xyz contains three integers)pbdoc") + .def(py::init(), + py::arg("x"), py::arg("y"), py::arg("z")) + .def_readwrite("x", &dim3::x, R"pbdoc(get value x)pbdoc") + .def_readwrite("y", &dim3::y, R"pbdoc(get value y)pbdoc") + .def_readwrite("z", &dim3::z, R"pbdoc(get value z)pbdoc"); + + bind_identity_swizzle>(m, "IdentitySwizzle1"); + bind_identity_swizzle>(m, "IdentitySwizzle2"); + bind_identity_swizzle>(m, "IdentitySwizzle4"); + bind_identity_swizzle>(m, "IdentitySwizzle8"); + + bind_swizzle(m, "HorizontalSwizzle", R"pbdoc(Threadblock swizzling function for GEMMs)pbdoc"); + bind_swizzle(m, "BatchedIdentitySwizzle", R"pbdoc(Threadblock swizzling function for batched GEMMs)pbdoc"); + + bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle1"); + bind_dgrad_swizzle>(m, "StridedDgradIdentitySwizzle4"); + bind_dgrad_swizzle(m, "StridedDgradHorizontalSwizzle"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h b/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h new file mode 100644 index 00000000..231a21d5 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/tensor_coord.h @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
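A hedged sketch of the swizzle bindings above; `cutlass.gemm.GemmCoord(m, n, k)` is referenced by the docstrings but bound outside this hunk, so its constructor signature is an assumption:

import cutlass

swizzle = cutlass.IdentitySwizzle1()
tiled = swizzle.get_tiled_shape(
    cutlass.gemm.GemmCoord(1024, 512, 256),        # GEMM problem size (assumed ctor)
    cutlass.gemm.GemmCoord(128, 128, 32),          # threadblock tile
    1)                                             # split-K slices
grid = swizzle.get_grid_shape(tiled)               # dim3 bound above: x, y, z fields
print(grid.x, grid.y, grid.z)
print(swizzle.tag())                               # demangled C++ type name for code emission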
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind Tensor Coord to python +*/ +#pragma once +#include +#include + +#include "cutlass/tensor_coord.h" + +namespace py = pybind11; + +void bind_tensor_coord(py::module &m) { + // + // Tensor Coords + // cutlass/include/cutlass/tensor_coord.h + // + + /// Defines a canonical 4D coordinate used by tensor operations. + py::class_(m, "Tensor4DCoord", + R"pbdoc(Defines a canonical 4D coordinate used by tensor operations)pbdoc") + .def(py::init(), + py::arg("n"), py::arg("h"), py::arg("w"), py::arg("c"), + R"pbdoc(Helper to construct from N, H, W, and C)pbdoc"); + + py::class_>(m, "Tensor3DCoord", + R"pbdoc(Defines a canonical 3D coordinate used by tensor operations)pbdoc") + .def("at", py::overload_cast(&cutlass::Coord<3>::at), + py::arg("dim"), + R"pbdoc(Gets the index of a given Coord element)pbdoc"); + + // Matrix Size + py::class_(m, "MatrixCoord", + R"pbdoc(MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes + expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord.)pbdoc") + .def(py::init(), + py::arg("row"), py::arg("column"), R"pbdoc(Helper to construct from a row and column)pbdoc") + .def("row", py::overload_cast<>(&cutlass::MatrixCoord::row), + R"pbdoc(Returns the row of the coordinate)pbdoc") + .def("column", py::overload_cast<>(&cutlass::MatrixCoord::column), + R"pbdoc(Returns the column of the coordinate)pbdoc"); + +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h b/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h new file mode 100644 index 00000000..60bb19da --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/tensor_ref_view.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
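A short sketch of the coordinate helpers bound above (assuming the extension imports as `cutlass`):

import cutlass

activation = cutlass.Tensor4DCoord(1, 7, 7, 64)    # N, H, W, C
problem = cutlass.MatrixCoord(128, 64)             # rank-2 (row, column) coordinate
print(problem.row(), problem.column())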
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSE +#include + +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "types.h" + + +template +void bind_tensor_ref_view(py::module &m, std::string name) { + py::class_>(m, ("TensorRef" + name).c_str()) + .def("__init__", [](cutlass::TensorRef& tensor_ref, int64_t address, const L& layout_ ) { + T* ptr = reinterpret_cast< T*>(address); + new (&tensor_ref) cutlass::TensorRef(ptr, layout_); + }) + .def("data", [](cutlass::TensorRef& tensor_ref) { + T* ptr = tensor_ref.data(); + return int64_t(ptr); + }) + .def("layout", py::overload_cast<>(&cutlass::TensorRef::layout)); + + m.def("get_tensor_ref", [](int64_t address, TF data, const L& layout_) { + T* ptr = reinterpret_cast(address); + cutlass::TensorRef tensor_ref = cutlass::TensorRef(ptr, layout_); + return tensor_ref; + }); + + py::class_>(m, ("TensorView" + name).c_str()) + .def(py::init&, const typename L::TensorCoord &>()); +} + + +void bind_tensor_refs_and_views(py::module &m) { + + /// float + bind_tensor_ref_view(m, "F32RowMajor"); + bind_tensor_ref_view(m, "F32ColumnMajor"); + bind_tensor_ref_view(m, "F32NHWC"); + + /// double + bind_tensor_ref_view(m, "F64RowMajor"); + bind_tensor_ref_view(m, "F64ColumnMajor"); + bind_tensor_ref_view(m, "F64NHWC"); + + // half_t + bind_tensor_ref_view(m, "F16RowMajor"); + bind_tensor_ref_view(m, "F16ColumnMajor"); + bind_tensor_ref_view(m, "F16NHWC"); + + // bfloat16 + bind_tensor_ref_view(m, "BF16RowMajor"); + bind_tensor_ref_view(m, "BF16ColumnMajor"); + bind_tensor_ref_view(m, "BF16NHWC"); + + // int8_t + bind_tensor_ref_view, cutlass::int8>(m, "S8RowMajorInterleaved32"); + bind_tensor_ref_view, cutlass::int8>(m, "S8ColumnMajorInterleaved32"); + bind_tensor_ref_view(m, "S8RowMajor"); + bind_tensor_ref_view(m, "S8ColumnMajor"); + bind_tensor_ref_view(m, "S8NHWC"); + bind_tensor_ref_view, cutlass::int8>(m, "S8NC32HW32"); + bind_tensor_ref_view, cutlass::int8>(m, "S8C32RSK32"); + + // int32_t + bind_tensor_ref_view(m, "S32RowMajor"); + bind_tensor_ref_view(m, "S32ColumnMajor"); + bind_tensor_ref_view(m, "S32NHWC"); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/include/types.h b/tools/library/scripts/pycutlass/src/cpp/include/types.h new file mode 100644 index 00000000..ec4daa94 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/include/types.h @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
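A hedged sketch of the TensorRef bindings above; the Python class name follows the "TensorRef" + suffix pattern used in bind_tensor_refs_and_views, and the numpy buffer merely supplies an address:

import numpy as np
import cutlass

a = np.zeros((128, 64), dtype=np.float32)
layout = cutlass.ColumnMajor.packed(cutlass.MatrixCoord(128, 64))
ref = cutlass.TensorRefF32ColumnMajor(a.ctypes.data, layout)   # (address, layout)
print(hex(ref.data()), ref.layout().stride())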
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind CUTLASS types to python +*/ +#pragma once +#include +#include + +#include "cutlass/half.h" + + +namespace py = pybind11; + +namespace cutlass { + +/// IEEE 32-bit signed integer +struct alignas(1) int8 { + int8_t storage; + explicit int8(int x) { + storage = int8_t(x); + } + explicit int8(float x) { + storage = int8_t(x); + } + + int8_t c_value(){return storage;} +}; + +/// IEEE 32-bit signed integer +struct alignas(4) int32 { + int storage; + explicit int32(int x) { + storage = x; + } + explicit int32(float x) { + storage = int(x); + } + + int c_value(){return storage;} +}; +/// IEEE single-precision floating-point type +struct alignas(4) float32 { + float storage; + explicit float32(float x) { + storage = x; + } + explicit float32(int x) { + storage = float(x); + } + float c_value(){return storage;} +}; +/// IEEE double-precision floating-point type +struct alignas(4) float64 { + double storage; + explicit float64(float x) { + storage = double(x); + } + explicit float64(int x) { + storage = double(x); + } + double c_value(){return storage;} +}; +} + +void bind_cutlass_types(py::module &m) { + + // s8 + py::class_(m, "int8") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::int8::storage) + .def("value", &cutlass::int8::c_value); + + // s32 + py::class_(m, "int32") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::int32::storage) + .def("value", &cutlass::int32::c_value); + + // f16 + py::class_(m, "float16") + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::half_t::storage) + .def("value", [](const cutlass::half_t& value) {return value;}); + + // bf16 + py::class_(m, "bfloat16") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::bfloat16_t::storage) + .def("value", [](const cutlass::bfloat16_t& value) {return value;}); + + // f32 + py::class_(m, "float32") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::float32::storage) + .def("value", &cutlass::float32::c_value); + + // tf32 + py::class_(m, "tfloat32") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::tfloat32_t::storage) + .def("value", [](const cutlass::tfloat32_t& value) {return value;}); + + // f64 + py::class_(m, "float64") + .def(py::init()) + .def(py::init()) + .def_readwrite("storage", &cutlass::float64::storage) + .def("value", 
&cutlass::float64::c_value); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/library.h b/tools/library/scripts/pycutlass/src/cpp/library.h new file mode 100644 index 00000000..5d46f69d --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/library.h @@ -0,0 +1,32 @@ +#include + +namespace cutlass { + +/// ENUM class for datatypes +enum class DataType { + kB1, kU2, kU4, kU8, + kU16, kU32, kU64, kS2, + kS4, kS8, kS16, kS32, + kS64, kF16, kBF16, kF32, + kTF32, kF64, kCF16, kCBF16, + kCF32, kCTF32, kCF64, kCS2, + kCS4, kCS8, kCS16, kCS32, + kCS64, kCU2, kCU4, kCU8, + kCU16, kCU32, kCU64, kInvalid +}; + +/// ENUM class for LayoutTypes +enum class LayoutType { + kColumnMajor, kRowMajor, + kColumnMajorInterleaved2, kRowMajorInterleaved2, + kColumnMajorInterleaved32, kRowMajorInterleaved32, + kColumnMajorInterleaved64, kRowMajorInterleaved64, + kTensorNHWC, kTensorNDHWC, kTensorNCHW, kTensorNGHWC, + kTensorNC32HW32, kTensorNC64HW64, kTensorC32RSK32, + kTensorC64RSK64 +}; + +/// ENUM class for opcode class + + +} // namespace cutlass diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h b/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h new file mode 100644 index 00000000..555d9900 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/test/conv/conv_problems.h @@ -0,0 +1,54 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
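A minimal sketch of the scalar wrappers bound above (assuming the extension imports as `cutlass`):

import cutlass

alpha = cutlass.float32(1.0)
beta = cutlass.float32(0.0)
print(alpha.value(), beta.value())
i = cutlass.int8(127)
print(i.value(), i.storage)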
+ * + **************************************************************************************************/ +/* \file + \brief Bind convolution problems to python +*/ +#pragma once +#include +#include + + +#include "unit/conv/device/conv2d_problems.h" +#include "cutlass/conv/conv2d_problem_size.h" + +namespace py = pybind11; + +PYBIND11_MAKE_OPAQUE(std::vector); + +void bind_conv_problem_size_test(py::module &m) { + + py::bind_vector>(m, "Conv2dProblemVector") + .def("size", &std::vector::size); + // Get Conv2d problem sizes + py::class_(m, "TestbedConv2dProblemSizes") + .def(py::init()) + .def_readonly("conv2d_default_sizes", &test::conv::device::TestbedConv2dProblemSizes::conv2d_default_sizes); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h b/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h new file mode 100644 index 00000000..d1efc7fd --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/test/conv/convolution.h @@ -0,0 +1,49 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
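A hedged sketch of the convolution problem-size test helpers above. `conv_test` is a stand-in name; the submodule that bind_conv_problem_size_test() is attached to is defined outside this hunk. The integer constructor argument is the minimum channel size used by the CUTLASS unit-test problem list:

sizes = conv_test.TestbedConv2dProblemSizes(64)
for problem in sizes.conv2d_default_sizes:         # Conv2dProblemVector via py::bind_vector
    pass                                           # each entry is a cutlass::conv::Conv2dProblemSize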
+ * + **************************************************************************************************/ +/* \file + \brief Bind convolution related types to python +*/ +#pragma once +#include +#include + +#include "conv_problems.h" +#include "host.h" + +namespace py = pybind11; + +void bind_convolution_test(py::module &m) { + // Conv problem sizes + bind_conv_problem_size_test(m); + + py::module_ host_submodule = m.def_submodule("host"); + bind_conv_host_references(host_submodule); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h b/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h new file mode 100644 index 00000000..c4b9ef75 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/test/conv/host.h @@ -0,0 +1,180 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Bind Convolution host test helpers to python +*/ +#pragma once +#include +#include +#include "unit/conv/device/cache_testbed_output.h" + + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +namespace py = pybind11; + + +template +void bind_conv2d_host(py::module &m) { + m.def("conv2d", \ + &cutlass::reference::host::Conv2d< \ + Ta, La, Tb, Lb, Tc, Lc, Te, Tacc>); + + m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); +} + +template +void bind_conv2d_host_sat(py::module &m) { + m.def("conv2d", \ + &cutlass::reference::host::Conv2d< \ + Ta, La, Tb, Lb, Tc, Lc, Te, Tacc, cutlass::NumericConverterClamp>); + + m.def("CreateCachedConv2dTestKey", &test::conv::device::CreateCachedConv2dTestKey); +} + +template +void bind_conv2d_host_nhwc(py::module &m) { + bind_conv2d_host< + Ta, cutlass::layout::TensorNHWC, + Tb, cutlass::layout::TensorNHWC, + Tc, cutlass::layout::TensorNHWC, + Tacc, Te>(m); +} + +template +void bind_conv2d_host_nc32hw32(py::module &m) { + bind_conv2d_host_sat< + Ta, cutlass::layout::TensorNCxHWx<32>, + Tb, cutlass::layout::TensorCxRSKx<32>, + Tc, cutlass::layout::TensorNCxHWx<32>, + Tacc, Te>(m); +} + + +template +void bind_tensor_equals(py::module &m) { + m.def("equals", py::overload_cast< + const cutlass::TensorView&, const cutlass::TensorView&>( + &cutlass::reference::host::TensorEquals + )); +} + +#define BIND_TENSOR_HASH(Element, Layout) { \ + m.def("TensorHash", &test::conv::device::TensorHash, py::arg("view"), py::arg("hash") = test::conv::device::CRC32(), py::arg("crc")=uint32_t()); \ +} + +void bind_conv_host_references(py::module &m) { + // + // Conv2d reference on host + // tools/util/include/cutlass/util/reference/host/convolution.h + + /// double + bind_conv2d_host_nhwc(m); + /// float + bind_conv2d_host_nhwc(m); + /// half + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + /// bfloat16 + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + /// s8 + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + bind_conv2d_host_nhwc(m); + + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + bind_conv2d_host_nc32hw32(m); + + // + // Compare whether two tensors are equal + // + /// double + bind_tensor_equals(m); + /// float + bind_tensor_equals(m); + /// half + bind_tensor_equals(m); + /// bfloat16 + bind_tensor_equals(m); + /// s32 + bind_tensor_equals(m); + bind_tensor_equals>(m); + /// s8 + bind_tensor_equals(m); + bind_tensor_equals>(m); + + /// Cache + py::class_(m, "CachedTestKey") + .def(py::init<>()) + .def(py::init()); + + py::class_(m, "CachedTestResult") + .def(py::init<>()) + .def(py::init()) + .def_readwrite("D", &test::conv::device::CachedTestResult::D); + + py::class_(m, "CachedTestResultListing") + .def(py::init()) + .def("find", &test::conv::device::CachedTestResultListing::find) + .def("append", 
&test::conv::device::CachedTestResultListing::append) + .def("write", &test::conv::device::CachedTestResultListing::write); + + py::class_(m, "CRC32") + .def(py::init<>()); + + BIND_TENSOR_HASH(double, cutlass::layout::TensorNHWC) + BIND_TENSOR_HASH(float, cutlass::layout::TensorNHWC); + BIND_TENSOR_HASH(cutlass::half_t, cutlass::layout::TensorNHWC); + BIND_TENSOR_HASH(cutlass::bfloat16_t, cutlass::layout::TensorNHWC); + BIND_TENSOR_HASH(int32_t, cutlass::layout::TensorNHWC); + BIND_TENSOR_HASH(int8_t, cutlass::layout::TensorNCxHWx<32>); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h b/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h new file mode 100644 index 00000000..5756df47 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/test/gemm/gemm.h @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind gemm test to python +*/ +#pragma once +#include +#include + +#include "host.h" + +namespace py = pybind11; + +void bind_gemm_test(py::module &m) { + py::module_ host_submodule = m.def_submodule("host"); + bind_gemm_host_reference(host_submodule); +} diff --git a/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h b/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h new file mode 100644 index 00000000..155520da --- /dev/null +++ b/tools/library/scripts/pycutlass/src/cpp/test/gemm/host.h @@ -0,0 +1,431 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Bind gemm test host functions to python +*/ +#pragma once +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/functional.h" + +namespace py = pybind11; + + +template< + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename AccumulatorType, typename ComputeType, + typename InnerProductOp> +void bind_host_gemm_saturate(py::module &m) { + m.def("gemm_saturate", py::overload_cast< + cutlass::gemm::GemmCoord, ComputeType, + cutlass::TensorRef, + cutlass::TensorRef, + ComputeType, + cutlass::TensorRef, + cutlass::TensorRef, + AccumulatorType>( + &cutlass::reference::host::compute_gemm< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ComputeType, + AccumulatorType, + InnerProductOp, + cutlass::NumericConverterClamp> + )); +} + +template< + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename AccumulatorType, typename ComputeType, + typename InnerProductOp> +void bind_host_gemm(py::module &m) { + m.def("gemm", py::overload_cast< + cutlass::gemm::GemmCoord, ComputeType, + cutlass::TensorRef, + cutlass::TensorRef, + ComputeType, + cutlass::TensorRef, + cutlass::TensorRef, + AccumulatorType>( + &cutlass::reference::host::compute_gemm< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ComputeType, + AccumulatorType, + InnerProductOp, + cutlass::NumericConverter> + )); +} + + +template< + typename ElementA, typename ElementB, typename ElementC, + typename AccumulatorType, typename ComputeType> +void bind_host_gemm_multiply_add(py::module &m) { + bind_host_gemm< + ElementA, cutlass::layout::RowMajor, + 
ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::RowMajor, + ComputeType, AccumulatorType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); +} + +template< + typename ElementA, typename ElementB, typename ElementC, + typename AccumulatorType, typename ComputeType> +void bind_host_gemm_multiply_add_saturate(py::module &m) { + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::RowMajor, + ComputeType, AccumulatorType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::RowMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, cutlass::layout::ColumnMajor, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); +} + + +template< + typename ElementA, typename ElementB, typename ElementC, + typename AccumulatorType, typename 
ComputeType> +void bind_host_gemm_multiply_add_interleaved(py::module &m) { + bind_host_gemm< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + ComputeType, AccumulatorType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); +} + +template< + typename ElementA, typename ElementB, typename ElementC, + typename AccumulatorType, typename ComputeType> +void bind_host_gemm_multiply_add_saturate_interleaved(py::module &m) { + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + ComputeType, AccumulatorType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::RowMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + 
ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::RowMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::RowMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); + + bind_host_gemm_saturate< + ElementA, cutlass::layout::ColumnMajorInterleaved<32>, + ElementB, cutlass::layout::ColumnMajorInterleaved<32>, + ElementC, cutlass::layout::ColumnMajorInterleaved<32>, + AccumulatorType, ComputeType, + cutlass::multiply_add>(m); +} + +#define BIND_TENSOR_EQUAL(Element, Layout) { \ + m.def("equals", py::overload_cast< \ + const cutlass::TensorView&, const cutlass::TensorView&>( \ + &cutlass::reference::host::TensorEquals)); \ +} + +void bind_gemm_host_reference(py::module &m) { + + /// double + bind_host_gemm_multiply_add(m); + /// float + bind_host_gemm_multiply_add(m); + /// half_t + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + /// bfloat16 + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + + /// s8 + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + bind_host_gemm_multiply_add(m); + + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + bind_host_gemm_multiply_add_interleaved(m); + + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + bind_host_gemm_multiply_add_saturate(m); + + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + bind_host_gemm_multiply_add_saturate_interleaved(m); + + // float + BIND_TENSOR_EQUAL(float, cutlass::layout::RowMajor); + BIND_TENSOR_EQUAL(float, cutlass::layout::ColumnMajor); + + // double + BIND_TENSOR_EQUAL(double, cutlass::layout::RowMajor); + BIND_TENSOR_EQUAL(double, cutlass::layout::ColumnMajor); + + // half_t + BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::RowMajor); + BIND_TENSOR_EQUAL(cutlass::half_t, cutlass::layout::ColumnMajor); + + // bfloat16 + BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::RowMajor); + BIND_TENSOR_EQUAL(cutlass::bfloat16_t, cutlass::layout::ColumnMajor); + + // int32_t + BIND_TENSOR_EQUAL(int32_t, cutlass::layout::RowMajor); + BIND_TENSOR_EQUAL(int32_t, cutlass::layout::ColumnMajor); + + // int8_t + BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajor); + 
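A hedged sketch of the GEMM host-reference bindings registered above. `gemm_host` is a stand-in for the `host` submodule created by bind_gemm_test(); `ref_A` through `ref_D` and `view_D` / `view_ref_D` stand for TensorRef / TensorView objects built as in the earlier sketches. The argument order follows the compute_gemm overload bound above:

problem = cutlass.gemm.GemmCoord(128, 128, 64)     # assumed GemmCoord ctor
gemm_host.gemm(problem, 1.0, ref_A, ref_B, 0.0, ref_C, ref_D, 0.0)
passed = gemm_host.equals(view_D, view_ref_D)      # element-wise TensorView comparison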
BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajor); + BIND_TENSOR_EQUAL(int8_t, cutlass::layout::RowMajorInterleaved<32>); + BIND_TENSOR_EQUAL(int8_t, cutlass::layout::ColumnMajorInterleaved<32>); + + +} diff --git a/tools/library/scripts/pycutlass/src/pycutlass/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py new file mode 100644 index 00000000..40a19433 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/__init__.py @@ -0,0 +1,31 @@ +from pycutlass.type import * +from pycutlass.tensor_ref import * +from pycutlass.operation import * +from pycutlass.epilogue import * +from pycutlass.compiler import ArtifactManager +from pycutlass.memory_manager import * +from pycutlass.arguments import * +from pycutlass.library import * +from pycutlass.c_types import * +from pycutlass.gemm_operation import * +from pycutlass.conv2d_operation import * +from pycutlass.compiler import * +from pycutlass.utils import * +from pycutlass.frontend import * +from pycutlass.reduction_operation import * +from pycutlass.compiler import * + +# module-wide variables + +import sys +this = sys.modules[__name__] + +# artifact manager +this.compiler = ArtifactManager() + +def get_memory_pool(init_pool_size=0, max_pool_size=2**34): + this.memory_pool = PoolMemoryManager( + init_pool_size=init_pool_size, + max_pool_size=max_pool_size + ) + return this.memory_pool diff --git a/tools/library/scripts/pycutlass/src/pycutlass/arguments.py b/tools/library/scripts/pycutlass/src/pycutlass/arguments.py new file mode 100644 index 00000000..9c6bc5d2 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/arguments.py @@ -0,0 +1,104 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
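A short sketch of the module-level helpers defined in __init__.py above:

import pycutlass

pool = pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
print(pycutlass.compiler)                          # module-wide ArtifactManager instance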
+# +################################################################################################# +from .frontend import CupyFrontend +from typeguard import typechecked +from pycutlass.frontend import * +from typing import Union +import numpy as np +from cuda import cuda +try: + import torch + torch_available = True +except ImportError: + torch_available = False +from cuda import cudart +try: + import cupy as cp + cupy_available = True +except ImportError: + cupy_available = False + + +# @typechecked +class ArgumentBase: + """ + Base class for operation arguments + """ + + def __init__(self, + A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', + B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', + C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', + D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]', + **kwargs) -> None: + + # preprocessing input tensors + if isinstance(A, np.ndarray): + self.host_D = D + self.buffer_A = NumpyFrontend.argument(A, False) + self.buffer_B = NumpyFrontend.argument(B, False) + self.buffer_C = NumpyFrontend.argument(C, False) + self.buffer_D = NumpyFrontend.argument(D, True) + self.ptr_A = self.buffer_A.ptr + self.ptr_B = self.buffer_B.ptr + self.ptr_C = self.buffer_C.ptr + self.ptr_D = self.buffer_D.ptr + elif torch_available and isinstance(A, torch.Tensor): + self.ptr_A = TorchFrontend.argument(A) + self.ptr_B = TorchFrontend.argument(B) + self.ptr_C = TorchFrontend.argument(C) + self.ptr_D = TorchFrontend.argument(D) + elif isinstance(A, cuda.CUdeviceptr): + self.ptr_A = A + self.ptr_B = B + self.ptr_C = C + self.ptr_D = D + elif cupy_available and isinstance(A, cp.ndarray): + self.ptr_A = CupyFrontend.argument(A) + self.ptr_B = CupyFrontend.argument(B) + self.ptr_C = CupyFrontend.argument(C) + self.ptr_D = CupyFrontend.argument(D) + else: + raise TypeError( + "Unsupported Frontend. Only support numpy and torch") + + def sync(self, stream_sync=True): + if stream_sync: + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + if hasattr(self, "host_D"): + err, = cuda.cuMemcpyDtoH( + self.host_D, self.ptr_D, self.host_D.size * self.host_D.itemsize) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/c_types.py b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py new file mode 100644 index 00000000..1d6abdb2 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/c_types.py @@ -0,0 +1,252 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
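A hedged sketch of the numpy path through ArgumentBase above; in practice the class is used through operation-specific subclasses, but the constructor and sync() can be exercised directly:

import numpy as np
from pycutlass.arguments import ArgumentBase

A = np.random.rand(16).astype(np.float32)
B = np.random.rand(16).astype(np.float32)
C = np.zeros(16, dtype=np.float32)
D = np.zeros(16, dtype=np.float32)

args = ArgumentBase(A, B, C, D)                    # device buffers allocated via NumpyFrontend
args.sync()                                        # synchronizes and copies D back to the host array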
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes +from pycutlass.library import * + +# 12B + + +class GemmCoord_(ctypes.Structure): + _fields_ = [ + ("m", ctypes.c_int), + ("n", ctypes.c_int), + ("k", ctypes.c_int) + ] + + def __init__(self, gemm_coord) -> None: + for field_name, _ in self._fields_: + setattr(self, field_name, getattr(gemm_coord, field_name)()) + + +class MatrixCoord_(ctypes.Structure): + _fields_ = [ + ("row", ctypes.c_int), + ("column", ctypes.c_int) + ] + + +dtype2ctype = { + cutlass.float16: ctypes.c_uint16, + cutlass.float32: ctypes.c_float, + cutlass.float64: ctypes.c_double, + cutlass.int32: ctypes.c_int32 +} + + +def get_epilogue_output_op(element_compute_): + element_compute = dtype2ctype[element_compute_] + + class _EpilogueOutputOpParams(ctypes.Structure): + _fields_ = [ + ("alpha", element_compute), + ("beta", element_compute), + ("alpha_ptr", ctypes.c_void_p), + ("beta_ptr", ctypes.c_void_p) + ] + return _EpilogueOutputOpParams + + +def get_gemm_arguments(element_compute_): + + _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) + + class _GemmArguments(ctypes.Structure): + _fields_ = [ + ("mode", ctypes.c_int), + ("problem_size", GemmCoord_), + ("batch_count", ctypes.c_int), + ("epilogue", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("batch_stride_A", ctypes.c_longlong), + ("batch_stride_B", ctypes.c_longlong), + ("batch_stride_C", ctypes.c_longlong), + ("batch_stride_D", ctypes.c_longlong), + ("stride_a", ctypes.c_longlong), + ("stride_b", ctypes.c_longlong), + ("stride_c", ctypes.c_longlong), + ("stride_d", ctypes.c_longlong), + ("lda", ctypes.c_longlong), + ("ldb", ctypes.c_longlong), + ("ldc", ctypes.c_longlong), + ("ldd", ctypes.c_longlong), + ("ptr_gather_A_indices", ctypes.c_void_p), + ("ptr_gether_B_indices", ctypes.c_void_p), + ("ptr_scatter_D_indices", ctypes.c_void_p) + ] + + return _GemmArguments, _EpilogueOutputOpParams + + +########################################################################################### +# GEMM Grouped +########################################################################################### + +# include/cutlass/gemm/kernel/gemm_grouped.h + +def get_gemm_grouped_arguments(element_compute_): + _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) + + class _GEMMGroupedArguments(ctypes.Structure): + _fields_ = [ + ("problem_sizes", ctypes.c_void_p), + ("problem_count", ctypes.c_int), + ("threadblock_count", ctypes.c_int), + ("output_op", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", 
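A minimal sketch of the ctypes structures generated above; cutlass.float32 is the scalar wrapper used as the key into dtype2ctype:

import ctypes
import cutlass
from pycutlass.c_types import get_gemm_arguments

_GemmArguments, _EpilogueOutputOpParams = get_gemm_arguments(cutlass.float32)
epilogue = _EpilogueOutputOpParams(alpha=1.0, beta=0.0)   # alpha_ptr / beta_ptr default to NULL
print(ctypes.sizeof(_GemmArguments))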
ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("lda", ctypes.c_void_p), + ("ldb", ctypes.c_void_p), + ("ldc", ctypes.c_void_p), + ("ldd", ctypes.c_void_p), + ("host_problem_sizes", ctypes.c_void_p) + ] + + return _GEMMGroupedArguments, _EpilogueOutputOpParams + +############################################################################################ +# Convolution2D +############################################################################################ + + +# We use the arguments as the interface + + +# include/cutlass/conv/conv2d_problem_size.h +# 64B +class Conv2DProblemSize(ctypes.Structure): + _fields_ = [ + ("N", ctypes.c_int), + ("H", ctypes.c_int), + ("W", ctypes.c_int), + ("C", ctypes.c_int), + ("P", ctypes.c_int), + ("Q", ctypes.c_int), + ("K", ctypes.c_int), + ("R", ctypes.c_int), + ("S", ctypes.c_int), + ("pad_h", ctypes.c_int), + ("pad_w", ctypes.c_int), + ("stride_h", ctypes.c_int), + ("stride_w", ctypes.c_int), + ("dilation_h", ctypes.c_int), + ("dilation_w", ctypes.c_int), + ("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1 + ("split_k_slices", ctypes.c_int), + ("groups", ctypes.c_int) + ] + + def __init__(self, problem_size) -> None: + for field_name, _ in self._fields_: + setattr(self, field_name, getattr(problem_size, field_name)) + + +# include/cutlass/layout/tensor.h +# 12B +class Layout4D(ctypes.Structure): + _fields_ = [ + ("stride", ctypes.c_int * 3) + ] + + def __init__(self, tensor_ref): + stride = tensor_ref.stride() + setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2))) + +# TODO: Tensor 5-D takes ("stride", ctypes.c_int * 4) + + +# include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +# TensorRef is basically cutlass::TensorRef; +# include/cutlass/tensor_ref.h +# 24B +class TensorRef_(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("layout", Layout4D) + ] + + def __init__(self, tensor_ref): + setattr(self, "ptr", tensor_ref.data()) + setattr(self, "layout", Layout4D(tensor_ref.layout())) + + +class TensorRef2D_(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("stride", ctypes.c_int) + ] + + +# include/cutlass/conv/kernel/implicit_gemm_convolution.h +# split_k_mode: kNone: 0, kSerial: 1, kParallel: 2, kParallelSerial: 3, kInvalid: 4 + +def get_conv2d_arguments(element_compute_): + _EpilogueOutputOpParams = get_epilogue_output_op(element_compute_) + + class _Conv2dArguments(ctypes.Structure): + _fields_ = [ + ("problem_size", Conv2DProblemSize), # 0 + ("ref_A", TensorRef_), # 72 + ("ref_B", TensorRef_), # 96 + ("ref_C", TensorRef_), # 120 + ("ref_D", TensorRef_), # 144 + ("output_op", _EpilogueOutputOpParams), # 168 + ("split_k_mode", ctypes.c_int) # 192 + ] + + return _Conv2dArguments, _EpilogueOutputOpParams + + +############################################################################################ +# Reduction +############################################################################################ + + +def get_reduction_params(element_compute_): + _EpilogueOutputParams = get_epilogue_output_op(element_compute_) + + class _ReductionParams(ctypes.Structure): + _fields_ = [ + ("problem_size", MatrixCoord_), + ("partitions", ctypes.c_int), + ("partition_stride", ctypes.c_longlong), + ("workspace", TensorRef2D_), + ("destination", TensorRef2D_), + ("source", TensorRef2D_), + ("output_op", _EpilogueOutputParams) + ] + return _ReductionParams, _EpilogueOutputParams diff --git a/tools/library/scripts/pycutlass/src/pycutlass/cache.py 
b/tools/library/scripts/pycutlass/src/pycutlass/cache.py new file mode 100644 index 00000000..322da90f --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/cache.py @@ -0,0 +1,366 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from pycutlass import * +from pycutlass.library import SubstituteTemplate +import cutlass +from cuda import cuda +from cuda import nvrtc +import tempfile +import os +import ctypes + +# +import json +import sqlite3 + + +IncludeTemplate = r'''#include "${include}" +''' + +# +class CompilationOptions: + ''' + Compilation options. 
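+
+    Collects the target architectures and include paths for an emitted module
+    and renders them, together with the fixed '-std=c++11 -default-device'
+    flags, into the option list consumed by nvrtcCompileProgram.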
+ ''' + + # + def __init__(self, architectures = [80], include_paths = []): + self.includes = [] + self.include_paths = include_paths + self.flags = ['-std=c++11', '-default-device'] + self.architectures = architectures + + # + def get(self): + options = [] + + for flag in self.flags: + options.append(bytes(str.encode(flag))) + + for incl in self.include_paths: + options.append(bytes(str.encode('--include-path=%s' % incl))) + + arch_list = "-arch=" + for idx, arch in enumerate(self.architectures): + if idx: + arch_list += "," + arch_list += "sm_%d" % arch + + options.append(bytes(str.encode(arch_list))) + + return options + +def convertToBinaryData(filename): + with open(filename, 'rb') as file: + blobData = file.read() + return blobData + +def CDLLBin(host_binary): + tempfile.tempdir = "./" + temp_so = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True) + with open(temp_so.name, 'wb') as file: + file.write(host_binary) + host_lib = ctypes.CDLL(temp_so.name) + return host_lib + + +class ArtifactManager: + """ + Artifact manager + """ + def __init__(self) -> None: + try: + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)""" + cursor.execute(sqlite_create_table_query) + connection.commit() + cursor.close() + except: + pass + + self.compiled_cache_device = cutlass.CompileCache() + self.compiled_cache_host = cutlass.CompileCache() + + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" + + hostbin = convertToBinaryData(hostfile) + + data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) + + cursor.execute(sqlite_insert_blob_query, data_tuple) + connection.commit() + cursor.close() + + def load_operation(self, op_key): + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" + # try: + cursor.execute(sqlite_fetch_blob_query, (op_key, )) + record = cursor.fetchall() + if len(record) == 0: + return False + for row in record: + key, cubin_image, host_binary, operation_name, op_attr = row + op_attr = json.loads(op_attr) + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) + self.compiled_cache_device.insert(key, kernel) + + compiled_host_fns = {} + host_lib = CDLLBin(host_binary) + + func_name = operation_name + '_get_params' + func = getattr(host_lib, func_name) + func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) + compiled_host_fns['get_args'] = func + + func_name = operation_name + '_shared_memory_size' + func = getattr(host_lib, func_name) + compiled_host_fns['shared_memory_capacity'] = func() + + for attr in op_attr: + if isinstance(attr, str): + func_name = operation_name + '_' + attr + func = getattr(host_lib, func_name) + compiled_host_fns[attr] = func + + self.compiled_cache_host.insert(key, compiled_host_fns) + return True + + + def emit_compile_(self, operation_list, compilation_options): + """ + 
Compile a list of kernels and store them into database + """ + source_buffer_device = "" + source_buffer_host = "" + # 1. include + includes = [] + for operation in operation_list: + for incl in operation.emitter.includes: + if incl not in includes: + includes.append(incl) + + includes_host = [ + "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + for incl in includes: + source_buffer_device += SubstituteTemplate(IncludeTemplate, {'include': incl}) + + for incl in includes_host: + if "/device/" not in incl: + source_buffer_host += SubstituteTemplate(IncludeTemplate, { 'include': incl} ) + + + # 2. Operations + for operation in operation_list: + source_buffer_device += operation.emit() + source_buffer_host += operation.emit() + values = { + 'operation_name': operation.name(), + 'operation_suffix': operation.emitter.operation_suffix + } + source_buffer_device += SubstituteTemplate(operation.KernelTemplate, values) + source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) + + # 3. compile + err, program = nvrtc.nvrtcCreateProgram( + str.encode(source_buffer_device), + bytes(str.encode("module.cu")), + 0, [], []) + + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + # Compile program + options = compilation_options.get() + + err, = nvrtc.nvrtcCompileProgram(program, len(options), options) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + + error_string = 'NVRTC Error: {}\n'.format(err) + + # Get log from compilation + err, logSize = nvrtc.nvrtcGetProgramLogSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + log = b' ' * logSize + err, = nvrtc.nvrtcGetProgramLog(program, log) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + raise RuntimeError(error_string + log.decode() + source_buffer_device) + + # Get data from compilation + err, dataSize = nvrtc.nvrtcGetCUBINSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + cubin_image = b' ' * dataSize + err, = nvrtc.nvrtcGetCUBIN(program, cubin_image) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + # compile the host code + options = compilation_options.get() + cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host + for opt in options: + opt = opt.decode("utf-8") + if opt not in ['-default-device', '-std=c++11', '-arch=sm_80']: + if '--include-path=' in opt: + cmd += " " + opt.replace('--include-path=', '-I') + else: + cmd += " "+ opt + + tempfile.tempdir = "./" + temp = tempfile.NamedTemporaryFile(prefix='host_func', suffix='.so', delete=True) + + cmd += ' - -shared -o %s' % temp.name + os.system(cmd) + host_lib = ctypes.CDLL(temp.name) + + return cubin_image, host_lib, temp + + + def add_module(self, operations, compile_options=None): + """ + Insert a new compiled device module + """ + if compile_options is None: + cutlass_path = os.getenv('CUTLASS_PATH') + assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." + cuda_install_path = os.getenv('CUDA_INSTALL_PATH') + assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." 
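+            # Default compilation options: target every minimum compute
+            # capability required by the operations, and add the CUDA toolkit
+            # and CUTLASS include directories located above.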
+ architectures = [] + for operation in operations: + if hasattr(operation, "tile_description"): + cc = operation.tile_description.minimum_compute_capability + if cc not in architectures: + architectures.append(cc) + include_paths = [ + cuda_install_path + '/include', + cutlass_path + '/include', + cutlass_path + '/tools/util/include', + ] + compile_options = CompilationOptions(architectures, include_paths) + # save the cubin + operation_key = [] + operation_list = [] + for operation in operations: + # step 1: get kernel string as key + key = operation.rt_module.emit() + operation.procedural_name() + # step 1: check if the operation is in cache + compiled_kernel = self.compiled_cache_device.at(key) + + if compiled_kernel is None: + hit = self.load_operation(key) + if hit: + compiled_kernel = self.compiled_cache_device.at(key) + assert compiled_kernel is not None + if compiled_kernel is not None: + operation.rt_module.kernel = compiled_kernel + compiled_host_fns = self.compiled_cache_host.at(key) + assert compiled_host_fns is not None + for key in compiled_host_fns.keys(): + setattr(operation.rt_module, key, compiled_host_fns[key]) + operation.rt_module.initialize() + else: + operation_list.append(operation.rt_module) + operation_key.append(key) + if len(operation_list) > 0: + cubin_image, host_lib, host_file = self.emit_compile_(operation_list, compile_options) + + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + operation_name = [] + operation_attr = [] + for operation, key in zip(operation_list, operation_key): + # get device kernels + err, operation.kernel = cuda.cuModuleGetFunction( + module, + bytes(str.encode(operation.name())) + ) + operation_name.append(operation.name()) + self.compiled_cache_device.insert(key, operation.kernel) + # get host functions + compiled_host_fns = {} + op_attr = [] + + # get param size + func_name = operation.name() + '_get_param_size' + func = getattr(host_lib, func_name) + param_size = func() + + func_name = operation.name() + '_get_params' + func = getattr(host_lib, func_name) + func.argtype = operation.argtype + func.restype = ctypes.POINTER(ctypes.c_char * param_size) + setattr(operation, 'get_args', func) + compiled_host_fns['get_args'] = func + + # set shared memory size + func_name = operation.name() + '_shared_memory_size' + func = getattr(host_lib, func_name) + setattr(operation, 'shared_memory_capacity', func()) + compiled_host_fns['shared_memory_capacity'] = func() + # set the maximum dynamic shared size + operation.initialize() + + # get extra functions + op_attr.append(param_size) + + if hasattr(operation, "extra_funcs"): + for suffix in operation.extra_funcs: + func_name = operation.name() + '_' + suffix + func = getattr(host_lib, func_name) + setattr(operation, suffix, func) + compiled_host_fns[suffix] = func + op_attr.append(suffix) + + operation_attr.append(op_attr) + self.compiled_cache_host.insert(key, compiled_host_fns) + + for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr): + self.insert_operation(key, cubin_image, host_file.name, operation_name, operation_attr) + + +artifact_manager = ArtifactManager() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/compiler.py b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py new file mode 100644 index 00000000..158ff483 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/compiler.py @@ -0,0 +1,430 @@ 
+################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from pycutlass import * +from pycutlass.library import SubstituteTemplate +import cutlass +from cuda import cuda +from cuda import nvrtc +import tempfile +import os +import ctypes + +# +import json +import sqlite3 + + +IncludeTemplate = r'''#include "${include}" +''' + +# + + +class CompilationOptions: + ''' + Compilation options. 
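+
+    Holds the compiler flags, target architectures, and include paths for an
+    emitted module; get() renders them as NVRTC options and get_str() as the
+    option string passed to nvcc.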
+ ''' + + # + def __init__(self, flags, architectures=[80], include_paths=[]): + self.includes = [] + self.include_paths = include_paths + self.flags = flags + self.architectures = architectures + + def get_str(self): + options = "" + + for flag in self.flags: + options += " " + flag + + for incl in self.include_paths: + options += ' --include-path=%s' % incl + + arch_list = "-arch=" + for idx, arch in enumerate(self.architectures): + if idx: + arch_list += "," + arch_list += "sm_%d" % arch + + options += " " + arch_list + return options + + # + def get(self): + options = [] + + for flag in self.flags: + options.append(bytes(str.encode(flag))) + + for incl in self.include_paths: + options.append(bytes(str.encode('--include-path=%s' % incl))) + + arch_list = "-arch=" + for idx, arch in enumerate(self.architectures): + if idx: + arch_list += "," + arch_list += "sm_%d" % arch + + options.append(bytes(str.encode(arch_list))) + + return options + + +def convertToBinaryData(filename): + with open(filename, 'rb') as file: + blobData = file.read() + return blobData + + +def CDLLBin(host_binary): + tempfile.tempdir = "./" + temp_so = tempfile.NamedTemporaryFile( + prefix='host_func', suffix='.so', delete=True) + with open(temp_so.name, 'wb') as file: + file.write(host_binary) + host_lib = ctypes.CDLL(temp_so.name) + return host_lib + + +class ArtifactManager: + """ + Artifact manager + """ + + def __init__(self) -> None: + try: + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_create_table_query = """CREATE TABLE compiled_operations(op_key TEXT NOT NULL UNIQUE, cubin BLOB NOT NULL, hostbin BLOB NOT NULL, op_name TEXT NOT NULL, op_attrs TEXT NOT NULL)""" + cursor.execute(sqlite_create_table_query) + connection.commit() + cursor.close() + except: + pass + + self.backend = "nvrtc" + self.default_compile_options = [ + '-std=c++11', '-default-device', + ] + self.compiled_cache_device = cutlass.CompileCache() + self.compiled_cache_host = cutlass.CompileCache() + + def nvcc(self): + self.backend = "nvcc" + self.default_compile_options = [ + '-std=c++11', + ] + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" + + hostbin = convertToBinaryData(hostfile) + + data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) + + cursor.execute(sqlite_insert_blob_query, data_tuple) + connection.commit() + cursor.close() + + def load_operation(self, op_key): + connection = sqlite3.connect("./compiled_cache.db") + cursor = connection.cursor() + sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" + # try: + cursor.execute(sqlite_fetch_blob_query, (op_key, )) + record = cursor.fetchall() + if len(record) == 0: + return False + for row in record: + key, cubin_image, host_binary, operation_name, op_attr = row + op_attr = json.loads(op_attr) + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + err, kernel = cuda.cuModuleGetFunction( + module, bytes(str.encode(operation_name))) + self.compiled_cache_device.insert(key, kernel) + + compiled_host_fns = {} + host_lib = CDLLBin(host_binary) + + func_name = operation_name + '_get_params' + func = getattr(host_lib, func_name) + func.restype = 
ctypes.POINTER(ctypes.c_char * op_attr[0]) + compiled_host_fns['get_args'] = func + + func_name = operation_name + '_shared_memory_size' + func = getattr(host_lib, func_name) + compiled_host_fns['shared_memory_capacity'] = func() + + for attr in op_attr: + if isinstance(attr, str): + func_name = operation_name + '_' + attr + func = getattr(host_lib, func_name) + compiled_host_fns[attr] = func + + self.compiled_cache_host.insert(key, compiled_host_fns) + return True + + def emit_compile_(self, operation_list, compilation_options): + """ + Compile a list of kernels and store them into database + """ + source_buffer_device = "" + source_buffer_host = "" + # 1. include + includes = [] + for operation in operation_list: + for incl in operation.emitter.includes: + if incl not in includes: + includes.append(incl) + + includes_host = [ + "builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + for incl in includes: + source_buffer_device += SubstituteTemplate( + IncludeTemplate, {'include': incl}) + + for incl in includes_host: + if "/device/" not in incl: + source_buffer_host += SubstituteTemplate( + IncludeTemplate, {'include': incl}) + + # 2. Operations + for operation in operation_list: + source_buffer_device += operation.emit() + source_buffer_host += operation.emit() + values = { + 'operation_name': operation.name(), + 'operation_suffix': operation.emitter.operation_suffix + } + source_buffer_device += SubstituteTemplate( + operation.KernelTemplate, values) + source_buffer_host += SubstituteTemplate( + operation.HostTemplate, values) + + if self.backend == "nvrtc": + # 3. compile + err, program = nvrtc.nvrtcCreateProgram( + str.encode(source_buffer_device), + bytes(str.encode("module.cu")), + 0, [], []) + + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + # Compile program + options = compilation_options.get() + + err, = nvrtc.nvrtcCompileProgram(program, len(options), options) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + + error_string = 'NVRTC Error: {}\n'.format(err) + + # Get log from compilation + err, logSize = nvrtc.nvrtcGetProgramLogSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + log = b' ' * logSize + err, = nvrtc.nvrtcGetProgramLog(program, log) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + raise RuntimeError( + error_string + log.decode() + source_buffer_device) + + # Get data from compilation + err, dataSize = nvrtc.nvrtcGetCUBINSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + + cubin_image = b' ' * dataSize + err, = nvrtc.nvrtcGetCUBIN(program, cubin_image) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError('NVRTC Error: {}'.format(err)) + else: # with nvcc backend + # emit code + tempfile.tempdir = "./" + temp_cu = tempfile.NamedTemporaryFile( + prefix='kernel', suffix='.cu', delete=True) + temp_cubin = tempfile.NamedTemporaryFile( + prefix='kernel', suffix='.cubin', delete=True) + with open(temp_cu.name, 'w') as file: + file.write(source_buffer_device) + + # compile with nvcc + cuda_install_path = os.getenv('CUDA_INSTALL_PATH') + assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." 
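+            # Compile the emitted .cu into a standalone .cubin with nvcc, then
+            # read the binary image back so it can be loaded with
+            # cuModuleLoadData.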
+ cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" + values = { + "cuda_install_path": cuda_install_path, + "options": compilation_options.get_str(), + "srcfile": temp_cu.name, + "tarfile": temp_cubin.name + } + cmd = SubstituteTemplate(cmd_template, values) + os.system(cmd) + + # load the cubin image + with open(temp_cubin.name, 'rb') as file: + cubin_image = file.read() + + # compile the host code + options = compilation_options.get() + cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host + for opt in options: + opt = opt.decode("utf-8") + if opt not in ['-default-device', '-std=c++11', '-arch=sm_80', '-Xcicc', '-Xllc']: + if '--include-path=' in opt: + cmd += " " + opt.replace('--include-path=', '-I') + else: + cmd += " " + opt + + tempfile.tempdir = "./" + temp = tempfile.NamedTemporaryFile( + prefix='host_func', suffix='.so', delete=True) + + cmd += ' - -shared -o %s' % temp.name + os.system(cmd) + host_lib = ctypes.CDLL(temp.name) + + return cubin_image, host_lib, temp + + def add_module(self, operations, compile_options=None): + """ + Insert a new compiled device module + """ + if compile_options is None: + cutlass_path = os.getenv('CUTLASS_PATH') + assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." + cuda_install_path = os.getenv('CUDA_INSTALL_PATH') + assert cuda_install_path is not None, "Environment variable 'CUDA_INSTALL_PATH' is not defined." + architectures = [] + for operation in operations: + if hasattr(operation, "tile_description"): + cc = operation.tile_description.minimum_compute_capability + if cc not in architectures: + architectures.append(cc) + include_paths = [ + cuda_install_path + '/include', + cutlass_path + '/include', + cutlass_path + '/tools/util/include', + ] + compile_options = CompilationOptions( + self.default_compile_options, architectures, include_paths) + # save the cubin + operation_key = [] + operation_list = [] + for operation in operations: + # step 1: get kernel string as key + key = operation.rt_module.emit() + operation.procedural_name() + self.backend + # step 1: check if the operation is in cache + compiled_kernel = self.compiled_cache_device.at(key) + + if compiled_kernel is None: + hit = self.load_operation(key) + if hit: + compiled_kernel = self.compiled_cache_device.at(key) + assert compiled_kernel is not None + if compiled_kernel is not None: + operation.rt_module.kernel = compiled_kernel + compiled_host_fns = self.compiled_cache_host.at(key) + assert compiled_host_fns is not None + for key in compiled_host_fns.keys(): + setattr(operation.rt_module, key, compiled_host_fns[key]) + operation.rt_module.initialize() + else: + operation_list.append(operation.rt_module) + operation_key.append(key) + if len(operation_list) > 0: + cubin_image, host_lib, host_file = self.emit_compile_( + operation_list, compile_options) + + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + operation_name = [] + operation_attr = [] + for operation, key in zip(operation_list, operation_key): + # get device kernels + err, operation.kernel = cuda.cuModuleGetFunction( + module, + bytes(str.encode(operation.name())) + ) + operation_name.append(operation.name()) + self.compiled_cache_device.insert(key, operation.kernel) + # get host functions + compiled_host_fns = {} + op_attr = [] + + # get param size + func_name = operation.name() + '_get_param_size' + func = getattr(host_lib, 
func_name) + param_size = func() + + func_name = operation.name() + '_get_params' + func = getattr(host_lib, func_name) + func.argtype = operation.argtype + func.restype = ctypes.POINTER(ctypes.c_char * param_size) + setattr(operation, 'get_args', func) + compiled_host_fns['get_args'] = func + + # set shared memory size + func_name = operation.name() + '_shared_memory_size' + func = getattr(host_lib, func_name) + setattr(operation, 'shared_memory_capacity', func()) + compiled_host_fns['shared_memory_capacity'] = func() + # set the maximum dynamic shared size + operation.initialize() + + # get extra functions + op_attr.append(param_size) + + if hasattr(operation, "extra_funcs"): + for suffix in operation.extra_funcs: + func_name = operation.name() + '_' + suffix + func = getattr(host_lib, func_name) + setattr(operation, suffix, func) + compiled_host_fns[suffix] = func + op_attr.append(suffix) + + operation_attr.append(op_attr) + self.compiled_cache_host.insert(key, compiled_host_fns) + + for key, operation_name, operation_attr in zip(operation_key, operation_name, operation_attr): + self.insert_operation( + key, cubin_image, host_file.name, operation_name, operation_attr) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py new file mode 100644 index 00000000..fed535b6 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py @@ -0,0 +1,645 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +from typeguard import typechecked +from cuda import cuda +from typing import Union +import numpy as np + +from typeguard import typechecked + +from pycutlass import * + + +# @typechecked +class Conv2dArguments(ArgumentBase): + """ + Argument wrapper for Conv2d. 
It encodes problem information and + user-provide tensors into the kernel's argument. + + :param operation: the Conv2d operation to take the argument + :type operation: :class:`pycutlass.Conv2dOperation` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param split_k_mode: conv2d split K mode, defaults to + cutlass.conv.SplitKMode.Serial + :type split_k_mode: cutlass.conv.SplitKMode, optional + + :param output_op: output operator, optional + :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + + """ + + def __init__(self, operation: 'Conv2dOperation', + problem_size: 'cutlass.conv.Conv2dProblemSize', + A: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', + B: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', + C: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', + D: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', + split_k_mode: 'cutlass.conv.SplitKMode' + = cutlass.conv.SplitKMode.Serial, **kwargs) -> None: + + #: convolution kind + self.conv_kind: cutlass.conv.Operator = operation.conv_kind + self.layout_A: cutlass.layout = operation.A.layout + self.layout_B: cutlass.layout = operation.B.layout + self.layout_C: cutlass.layout = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if self.layout_C == cutlass.TensorNC32HW32: + B = self.reorder_tensor_B(B, problem_size) + + super().__init__(A, B, C, D, **kwargs) + # preprocessing output ops + if "output_op" in kwargs.keys() and \ + split_k_mode != cutlass.conv.SplitKMode.Parallel: + self.alpha = kwargs["output_op"].alpha + self.beta = kwargs["output_op"].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + self.element_compute = operation.element_epilogue + + if "split_k_slices" in kwargs.keys(): + self.split_k_mode = split_k_mode + self.split_k_slices = kwargs["split_k_slices"] + else: + self.split_k_mode = cutlass.conv.SplitKMode.Serial + self.split_k_slices = 1 + + #: problem_size + self.problem_size: cutlass.conv.Conv2dProblemSize = problem_size + self.problem_size.split_k_slices = self.split_k_slices + + self.operation = operation + + # + # initialize the argument + # + self.initialize() + + # @typechecked + def reorder_tensor_B(self, tensor_B: 'np.ndarray', + problem_size: 'cutlass.conv.Conv2dProblemSize'): + """ + Reorder tensor_B for interleaved layout + + :param tensor_B: input tensor B + :type tensor_B: numpy.ndarray + :param problem_size: Conv2d problem size + :type problem_size: :class:`cutlass.conv.Conv2dProblemSize` + + :return: reordered tensor B + :rtype: numpy.ndarray + """ + reordered_tensor_B = np.empty_like(tensor_B) + tensor_ref_B = self.get_tensor_ref( + tensor_B, self.element_B, self.layout_B, problem_size, "b") + reordered_tensor_ref_B = self.get_tensor_ref( + reordered_tensor_B, self.element_B, + self.layout_B, problem_size, "b") + cutlass.conv.host.reorder_convK( + reordered_tensor_ref_B, tensor_ref_B, self.conv_kind, problem_size) + + return reordered_tensor_B + + def get_tensor_ref( + self, tensor, dtype, tensor_layout, problem_size, operand): + if operand == "a": + tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent( + self.conv_kind, problem_size) 
+ elif operand == "b": + tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent( + self.conv_kind, problem_size) + elif operand in ["c", "d"]: + tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent( + self.conv_kind, problem_size) + else: + raise ValueError("unknown operand: " + operand) + + layout = tensor_layout.packed(tensor_coord) + + return TensorRef(tensor, dtype, layout).tensor_ref + + def get_arguments(self, semaphore): + ref_A = TensorRef_(self.get_tensor_ref( + self.ptr_A, self.element_A, self.layout_A, self.problem_size, "a")) + ref_B = TensorRef_(self.get_tensor_ref( + self.ptr_B, self.element_B, self.layout_B, self.problem_size, "b")) + ref_C = TensorRef_(self.get_tensor_ref( + self.ptr_C, self.element_C, self.layout_C, self.problem_size, "c")) + ref_D = TensorRef_(self.get_tensor_ref( + self.ptr_D, self.element_C, self.layout_C, self.problem_size, "d")) + + if self.element_compute == cutlass.float16: + alpha = cutlass.float16(self.alpha).storage + beta = cutlass.float16(self.beta).storage + elif self.element_compute == cutlass.int32: + alpha = int(self.alpha) + beta = int(self.beta) + else: + alpha = self.alpha + beta = self.beta + + argument_type, epilogue_type = get_conv2d_arguments( + self.operation.element_epilogue) + + output_op = epilogue_type(alpha, beta, 0, 0) + + self.c_arguments = argument_type( + Conv2DProblemSize(self.problem_size), + ref_A, ref_B, ref_C, ref_D, output_op, self.split_k_mode + ) + + self.semaphore = semaphore + + def initialize(self): + """ + Initialize the kernel arguments handling following stuffs + 1. get kernel launch configuration including grid, cta size, + and dynamic shared memory capacity + 2. allocate and initialize device workspace + 3. get kernel params as bytearray for NVRTC input + """ + # get launch configuration + self.launch_config = self.operation.rt_module.plan(self) + + # allocate and initialize device workspace + device_workspace_size = \ + self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + # get kernel params as bytearray + semaphore = 0 + if workspace_ptr is not None and \ + self.split_k_mode == cutlass.conv.SplitKMode.Parallel: + self.ptr_D = workspace_ptr + elif workspace_ptr is not None and \ + self.split_k_mode == cutlass.conv.SplitKMode.Serial: + semaphore = workspace_ptr + + self.get_arguments(semaphore) + + params_ = self.operation.rt_module.get_args(ctypes.byref( + self.c_arguments), ctypes.c_void_p(int(self.semaphore))) + self.host_workspace = bytearray(params_.contents) + self.device_workspace = None + + def sync(self): + """ + Synchronize the arguments. If the input tensor is in host, + copy it from device to host. + """ + return super().sync() + + +# @typechecked +class Conv2dRT(ExecutableOperation): + """ + Conv2dRT manages the CUTLASS runtime components + """ + KernelTemplate = r''' +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
+ ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + ''' + + HostTemplate = r''' +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Arguments* arguments, int *semaphore=nullptr){ + typename ${operation_name}${operation_suffix}::Params* params; + params = new ${operation_name}${operation_suffix}::Params(*arguments, semaphore); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } +} + + ''' + + def __init__(self, operation: 'Conv2dOperation'): + super().__init__(operation) + + self.argtype = [ctypes.POINTER(get_conv2d_arguments( + operation.element_epilogue)[0]), ctypes.c_void_p] + self.conv_kind = operation.conv_kind + + self.operation: Conv2dOperation = operation + + self.emitter = EmitConv2dInstance('_type') + + self.threads: int = operation.tile_description.num_threads + + self.swizzle_functor = operation.swizzling_functor + + def emit(self): + return self.emitter.emit(self.operation) + + # @typechecked + def get_device_workspace_size(self, arguments: Conv2dArguments): + workspace_bytes = 0 + + launch_config = arguments.launch_config + + self.conv_kind = self.operation.conv_kind + + if arguments.split_k_mode == cutlass.conv.SplitKMode.Parallel: + problem_size = arguments.problem_size + workspace_bytes = DataTypeSize[self.operation.C.element] \ + * launch_config.grid[2] * cutlass.conv.implicit_gemm_tensor_c_size( + self.conv_kind, problem_size + ) // 8 + elif arguments.split_k_mode == cutlass.conv.SplitKMode.Serial and \ + arguments.split_k_slices > 1: + workspace_bytes = launch_config.grid[0] * launch_config.grid[1] * 4 + + return workspace_bytes + + # @typechecked + def plan(self, arguments: Conv2dArguments): + tile_size = cutlass.gemm.GemmCoord( + self.operation.tile_description.threadblock_shape[0], + self.operation.tile_description.threadblock_shape[1], + self.operation.tile_description.threadblock_shape[2] + ) + + grid = self.swizzle_functor.get_grid_shape( + self.swizzle_functor.get_tiled_shape( + self.conv_kind, arguments.problem_size, + tile_size, arguments.split_k_slices + ) + ) + return LaunchConfiguration( + [grid.x, grid.y, grid.z], [self.threads, 1, 1], + self.shared_memory_capacity) + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + +# + + +class Conv2dOperation: + """ + CUTLASS Conv2d operation description. 
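+
+    Once the operation has been compiled (for instance through
+    :class:`ArtifactManager`'s ``add_module``), it is launched by constructing
+    a :class:`pycutlass.Conv2dArguments` instance and calling its ``run``
+    method. A minimal sketch, assuming ``operation``, ``problem_size``, and the
+    host tensors ``tensor_A`` ... ``tensor_D`` are already set up::
+
+        arguments = Conv2dArguments(
+            operation, problem_size,
+            A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
+            output_op=LinearCombinationFunctorArguments(alpha=1.0, beta=0.0))
+        operation.run(arguments)
+        arguments.sync()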
+ + :param conv_kind: convolution operator + :type conv_kind: :class:`cutlass.conv.Operator` + + :param iterator_algorithm: Selects among several implementation + variants trading off performance with simplicity + :type iterator_algorithm: :class:`cutlass.conv.IteratorAlgorithm` + + :param arch: GPU compute capability (sm_xx) + :type arch: int + + :param tile_description: tile description + :type tile_description: :class:`pycutlass.TileDescription` + + :param A: tensor A description + :type A: :class:`pycutlass.TensorDescription` + + :param B: tensor B description + :type B: :class:`pycutlass.TensorDescription` + + :param C: tensor C description + :type C: :class:`pycutlass.TensorDescription` + + :param D: tensor D description + :type D: :class:`pycutlass.TensorDescription` + + :param element_epilogue: element type for computation in epilogue \ + :type element_epilogue: cutlass.int8 | cutlass.int32 | cutlass.float16 | \ + cutlass.bfloat16 | cutlass.float32 | cutlass.float64 + + :param stride_support: distinguish among partial specializations that \ + accelerate certain problems where convolution stride is unit \ + :type stride_support: :class:`cutlass.conv.StrideSupport` + + :param epilogue_functor: convolution epilogue functor + :type epilogue_functor: :class:`EpilogueFunctor` + + :param swizzling_functor: threadblock swizzling functor + """ + # + + def __init__(self, + conv_kind: cutlass.conv.Operator, + iterator_algorithm: cutlass.conv.IteratorAlgorithm, + arch: int, tile_description: TileDescription, + A: TensorDescription, B: TensorDescription, C: TensorDescription, + element_epilogue: Union[cutlass.int8, cutlass.int32, cutlass.float16, + cutlass.bfloat16, cutlass.float32, cutlass.float64], + stride_support, epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1): + + self.operation_kind: OperationKind = OperationKind.Conv2d + self.arch: int = arch + self.tile_description: TileDescription = tile_description + self.conv_kind = conv_kind + self.A: TensorDescription = A + self.B: TensorDescription = B + self.C: TensorDescription = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor() + + self.rt_module: Conv2dRT = Conv2dRT(self) + + def run(self, arguments: Conv2dArguments) -> cuda.CUresult: + """ + Launch the cuda kernel with input arguments + + :param arguments: conv2d arguments + :type arguments: :class:`pycutlass.Conv2dArguments` + """ + + # launch the kernel + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('CUDA Error %s' % str(err)) + + return err + + # + # Get function name + # + + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + # + + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp: + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv2dInstance: + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/conv/kernel/default_conv2d_fprop.h", + "cutlass/conv/kernel/default_conv2d_dgrad.h", + "cutlass/conv/kernel/default_conv2d_wgrad.h" + ] + self.template = """ +// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" +using ${operation_name}_base = +typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + +""" + + def emit(self, operation): + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / + operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 
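+            # the remaining substitutions fill in the accumulator type, target
+            # architecture, tile and instruction shapes, epilogue, and iterator
+            # algorithm of the DefaultConv2d* kernel template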
'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': operation.swizzling_functor.tag(), + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support], + 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else + MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + } + + return SubstituteTemplate(self.template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py new file mode 100644 index 00000000..2eb65797 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py @@ -0,0 +1,138 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import struct + + +def MaxAlignment(fmt): + align = 1 + for x in fmt: + align = max(align, struct.calcsize(x)) + return align + + +def AlignedOffset(offset, align): + remainder = (offset % align) + if remainder: + offset += (align - remainder) + return offset + +################################################################################################# +# +# Functors +# +################################################################################################# + +# + + +class Functor: + def __init__(self): + self.decl = '' + self.definition = '' + self.fmt = '' + self.identifier = '' + + # + def emit_declaration(self): + return self.decl + + # + def emit_definition(self): + return self.definition + + # + def size(self): + ''' + Size of the packed Params structure + ''' + return struct.calcsize(self.fmt) + + # + def alignment(self): + return MaxAlignment(self.fmt) + + # + def initialize(self, host_workspace, offset, arguments): + return offset + self.size() + +################################################################################################# + +# + + +class LinearCombinationFunctorArguments: + def __init__(self, alpha=1.0, beta=0.0): + self.alpha = alpha + self.beta = beta + self.alpha_ptr = 0 + self.beta_ptr = 0 + +# + + +class LinearCombinationFunctor(Functor): + def __init__(self): + super().__init__() + + self.decl = """ + cutlass::epilogue::thread::LinearCombination< + float, + 1, + float, + float + >""" + self.identifier = 'linear_combination' + self.fmt = "ffPP" + + # + def size(self): + ''' + Size of the packed Params structure + ''' + return struct.calcsize(self.fmt) + + # + def alignment(self): + return MaxAlignment(self.fmt) + + # + def initialize(self, host_workspace, offset, arguments): + + offset = AlignedOffset(offset, self.alignment()) + + struct.pack_into( + self.fmt, + host_workspace, offset, + arguments.alpha, arguments.beta, arguments.alpha_ptr, arguments.beta_ptr) + + return offset + self.size() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/frontend.py b/tools/library/scripts/pycutlass/src/pycutlass/frontend.py new file mode 100644 index 00000000..f09a9192 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/frontend.py @@ -0,0 +1,104 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import numpy as np +from cuda import cuda +from pycutlass.memory_manager import * +from typing import TYPE_CHECKING +try: + import torch + torch_available = True +except ImportError: + torch_available = False + if TYPE_CHECKING: + import torch + +try: + import cupy as cp + cupy_available = True +except ImportError: + cupy_available = False + if TYPE_CHECKING: + import cupy as cp + + +class NumpyFrontend: + """ + Frontend node for numpy + """ + + @staticmethod + def argument(np_tensor: 'np.ndarray', is_output: 'bool') -> cuda.CUdeviceptr: + """Convert the input numpy tensor to CUDA device pointer + + :param np_tensor: input numpy nd array + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + # copy the data to device + if is_output: + return device_mem_alloc(np_tensor.size * np_tensor.itemsize) + else: + return todevice(np_tensor) + + +class TorchFrontend: + """ + Frontend node for torch + """ + + @staticmethod + def argument(torch_tensor: 'torch.Tensor') -> cuda.CUdeviceptr: + """Convert the input torch tensor to CUDA device pointer + + :param torch_tensor: input torch tensor + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + + # check the device of torch_tensor + if not torch_tensor.is_cuda: + torch_tensor = torch_tensor.to("cuda") + + return cuda.CUdeviceptr(torch_tensor.data_ptr()) + + +class CupyFrontend: + """ + Frontend node for cupy + """ + + @staticmethod + def argument(cupy_ndarray: 'cp.ndarray'): + return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py new file mode 100644 index 00000000..4361d7ea --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py @@ -0,0 +1,1650 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import enum +import copy +import numpy as np +from typeguard import typechecked +import cutlass +from pycutlass import * +from cuda import cuda + + +################################################################################ +# +# Data structure modeling a GEMM operation +# +################################################################################ + + +def transpose_layout(layout: cutlass.layout): + if layout == cutlass.ColumnMajor: + return cutlass.RowMajor + elif layout == cutlass.RowMajor: + return cutlass.ColumnMajor + else: + raise ValueError("unsupported Layout {}".format(layout)) + + +# @typechecked +class GemmArguments(ArgumentBase): + """ + Argument wrapper for GEMM. 
It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`pycutlass.GemmOperationUniversal` | + :class:`pycutlass.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass.gemm.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass.gemm.Mode` + + :param output_op: output operator, optional + :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + """ + + def __init__( + self, operation: 'GemmOperation', problem_size: 'cutlass.gemm.GemmCoord', + A: 'Tensor', B: 'Tensor', C: 'Tensor', D: 'Tensor', + gemm_mode: 'cutlass.gemm.Mode'=cutlass.gemm.Mode.Gemm, **kwargs): + + self.operation = operation + + self.layout_A: cutlass.layout = operation.A.layout + self.layout_B: cutlass.layout = operation.B.layout + self.layout_C: cutlass.layout = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if (operation.C.layout in + [cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32]): + # reorder tensor B for interleaved layout output + B = self.reorder_tensor_B(B, problem_size) + + super().__init__(A, B, C, D, **kwargs) + + if operation.switched: + self.problem_size = cutlass.gemm.GemmCoord( + problem_size.n(), problem_size.m(), problem_size.k()) + self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A + else: + self.problem_size = cutlass.gemm.GemmCoord( + problem_size.m(), problem_size.n(), problem_size.k()) + + # get the leading dimension + self.lda = operation.A.layout.packed(self.problem_size.mk()).stride() + self.ldb = operation.B.layout.packed(self.problem_size.kn()).stride() + self.ldc = operation.C.layout.packed(self.problem_size.mn()).stride() + self.ldd = self.ldc + + if 'output_op' in kwargs.keys() and \ + gemm_mode != cutlass.gemm.Mode.GemmSplitKParallel: + self.alpha = kwargs['output_op'].alpha + self.beta = kwargs['output_op'].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + # get number of slices on k dimension + self.gemm_mode = gemm_mode + if 'split_k_slices' in kwargs.keys(): + self.split_k_slices = kwargs['split_k_slices'] + else: + self.split_k_slices = 1 + + self.batch_count = self.split_k_slices + + self.batched_stride_A = self.problem_size.m() * self.problem_size.k() + self.batched_stride_B = self.problem_size.n() * self.problem_size.k() + self.batched_stride_C = self.problem_size.m() * self.problem_size.n() + self.batched_stride_D = self.problem_size.m() * self.problem_size.n() + + if isinstance(self.operation, GemmOperationUniversal): + self.initialize() + + def reorder_tensor_B(self, tensor_B: 'np.ndarray', + problem_size: 'cutlass.gemm.GemmCoord'): + """ + Reorder tensor_B for interleaved layout + + :param tensor_B: input tensor B + :type tensor_B: numpy.ndarray + :param problem_size: GEMM problem size + :type problem_size: :class:`cutlass.gemm.GemmCoord` + + :return: reordered tensor B + :rtype: numpy.ndarray + """ + reordered_tensor_B = np.empty_like(tensor_B) + tensor_ref_B = 
self.get_tensor_ref( + tensor_B, self.element_B, self.layout_B, problem_size, "b" + ) + reordered_tensor_ref_B = self.get_tensor_ref( + reordered_tensor_B, self.element_B, self.layout_B, problem_size, "b" + ) + cutlass.gemm.host.reorder_column( + tensor_ref_B, reordered_tensor_ref_B, problem_size) + return reordered_tensor_B + + def get_tensor_ref( + self, tensor, dtype, tensor_layout, problem_size, operand): + if operand == "a": + tensor_coord = problem_size.mk() + elif operand == "b": + tensor_coord = problem_size.kn() + elif operand in ["c", "d"]: + tensor_coord = problem_size.mn() + else: + raise ValueError("unknonw operand: " + operand) + + layout = tensor_layout.packed(tensor_coord) + + return TensorRef(tensor, dtype, layout).tensor_ref + + def get_arguments(self): + problem_size_ = GemmCoord_(self.problem_size) + grid_tiled_shape_ = GemmCoord_( + cutlass.gemm.GemmCoord( + self.grid_tiled_shape.x, self.grid_tiled_shape.y, + self.grid_tiled_shape.z + ) + ) + + argument_type, epilogue_type = get_gemm_arguments( + self.operation.element_epilogue) + + if self.operation.element_epilogue == cutlass.float16: + self.alpha = cutlass.float16(self.alpha).storage + self.beta = cutlass.float16(self.beta).storage + elif self.operation.element_epilogue == cutlass.int32: + self.alpha = int(self.alpha) + self.beta = int(self.beta) + + output_op = epilogue_type(self.alpha, self.beta, 0, 0) + + arguments = argument_type( + self.gemm_mode, problem_size_, self.batch_count, output_op, + int(self.ptr_A), int(self.ptr_B), int(self.ptr_C), int(self.ptr_D), + self.batched_stride_A, self.batched_stride_B, self.batched_stride_C, + self.batched_stride_D, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + + self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size + + def initialize(self): + # get launch configuration + launch_config = self.operation.rt_module.plan(self) + + # get the host and evice workspace + device_workspace_size = \ + self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if (workspace_ptr is not None and + self.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel): + # in GEMM splik-K parallel, the D pointer is redirected + # to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif (workspace_ptr is not None and + self.gemm_mode == cutlass.gemm.Mode.Gemm): + # in GEMM split-K serial + device_workspace = workspace_ptr + + self.get_arguments() + + arguments, grid_tiled_shape, gemm_k_size = self.arguments + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), ctypes.byref(grid_tiled_shape), + gemm_k_size, ctypes.c_void_p(int(device_workspace))) + host_workspace = bytearray(res_arg.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + +class GemmGroupedArguments: + """ + Argument wrapper for GEMM Grouped. 
It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM Grouped operation to take the argument + :type operation: :class:`pycutlass.GemmOperationGrouped` + + :param problem_size: list of GEMM problem size gemm(M, N, K) + :type operation: list[:class:`cutlass.gemm.GemmCoord`] + + :param A: list of tensor A + :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param B: list of tensor B + :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param C: list of tensor C + :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param D: list of tensor D + :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param output_op: output operator, optional + :type output_op: :class:`pycutlass.LinearCombinationFunctorArguments` + """ + def __init__( + self, operation: 'GemmOperationGrouped', + problem_sizes: 'list[cutlass.gemm.GemmCoord]', + A: 'list[Tensor]', B: 'list[Tensor]', C: 'list[torch.Tensor]', + D: 'list[Tensor]', **kwargs): + + # get number of problems in the group + self.problem_count = len(problem_sizes) + + # check the input arguments + assert len(A) == self.problem_count + assert len(B) == self.problem_count + assert len(C) == self.problem_count + assert len(D) == self.problem_count + + problem_size_host = [] + self.ptr_A_host = [] + self.ptr_B_host = [] + self.ptr_C_host = [] + self.ptr_D_host = [] + + lda_host = [] + ldb_host = [] + ldc_host = [] + ldd_host = [] + + self.partitions = 1 + + self.operation = operation + + # get the threadblock + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = cutlass.gemm.GemmCoord( + threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) + self.threadblock_swizzle = operation.swizzling_functor + + self.total_tiles = 0 + + self.gemm_arguments = [] + + # process the input arguments + for idx, problem_size in enumerate(problem_sizes): + M, N, K = problem_size.m(), problem_size.n(), problem_size.k() + temp_argument = GemmArguments( + operation=operation, + problem_size=cutlass.gemm.GemmCoord(M, N, K), + A=A[idx], B=B[idx], C=C[idx], D=D[idx], + ) + self.gemm_arguments.append(temp_argument) + + problem_size_host.append( + [temp_argument.problem_size.m(), + temp_argument.problem_size.n(), + temp_argument.problem_size.k()] + ) + + self.ptr_A_host.append(int(temp_argument.ptr_A)) + lda_host.append(temp_argument.lda) + + self.ptr_B_host.append(int(temp_argument.ptr_B)) + ldb_host.append(temp_argument.ldb) + + self.ptr_C_host.append(int(temp_argument.ptr_C)) + ldc_host.append(temp_argument.ldc) + + self.ptr_D_host.append(int(temp_argument.ptr_D)) + ldd_host.append(temp_argument.ldd) + + # get number of tiles + grid = self.threadblock_swizzle.get_grid_shape( + self.threadblock_swizzle.get_tiled_shape( + temp_argument.problem_size, self.threadblock_shape, + temp_argument.batch_count) + ) + self.total_tiles += grid.x * grid.y * grid.z + + self.problem_size_buffer = todevice(problem_size_host, np.int32) + self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64) + self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64) + self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64) + self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64) + + self.lda_buffer = todevice(lda_host, np.int64) + self.ldb_buffer = todevice(ldb_host, np.int64) + self.ldc_buffer = todevice(ldc_host, np.int64) + self.ldd_buffer = todevice(ldd_host, np.int64) + + if 
'output_op' in kwargs.keys(): + self.alpha = kwargs['output_op'].alpha + self.beta = kwargs['output_op'].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + if self.operation.element_epilogue == cutlass.float16: + self.alpha = cutlass.float16(self.alpha).storage + self.beta = cutlass.float16(self.beta).storage + elif self.operation.element_epilogue == cutlass.int32: + self.alpha = int(self.alpha) + self.beta = int(self.beta) + + # get host problem size + self.host_problem_size_ptr = np.array( + problem_size_host, dtype=np.int32).__array_interface__['data'][0] + + self.arguments = self.get_arguments() + + self.initialize() + + def get_arguments(self): + + argument_type, epilogue_type = get_gemm_grouped_arguments( + self.operation.element_epilogue) + self.output_op = epilogue_type(self.alpha, self.beta, 0, 0) + + return argument_type( + self.problem_size_buffer.ptr, self.problem_count, self.total_tiles, + self.output_op, self.ptr_A_buffer.ptr, self.ptr_B_buffer.ptr, + self.ptr_C_buffer.ptr, self.ptr_D_buffer.ptr, self.lda_buffer.ptr, + self.ldb_buffer.ptr, self.ldc_buffer.ptr, self.ldd_buffer.ptr, + ctypes.c_void_p(int(self.host_problem_size_ptr)) + ) + + def initialize(self): + # get launch configuration + launch_config = self.operation.rt_module.plan(self) + + # get the host and evice workspace + device_workspace_size = \ + self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + if self.operation.precompute_mode == SchedulerMode.Host: + device_workspace_ptr = self.operation.rt_module.host_precompute( + self, self.operation.rt_module.get_workspace_size(self)) + else: + device_workspace_ptr = 0 + + result = self.operation.rt_module.get_args( + ctypes.byref(self.arguments), self.total_tiles, + ctypes.c_void_p(int(device_workspace_ptr)) + ) + host_workspace = bytearray(result.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + def sync(self): + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + for arg in self.gemm_arguments: + arg.sync(stream_sync=False) + + +################################################################################ +# Base class for GEMM runtime module +################################################################################ + +class GemmRTbase(ExecutableOperation): + """ + GemmRT manages the CUTLASS runtime components + """ + + KernelTemplate = r''' +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
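  // The SharedStorage for large tile shapes can exceed the 48 KB static
  // shared-memory limit, so it is taken from dynamically allocated shared
  // memory instead; GemmRTbase.initialize() below raises the per-kernel limit
  // via cuFuncSetAttribute(CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES).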
+ ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + ''' + + def __init__(self, operation: 'GemmOperation'): + super().__init__(operation) + + self.operation = operation + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = cutlass.gemm.GemmCoord( + threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) + self.threadblock_swizzle = operation.swizzling_functor + + #: number of threads per threadblock + self.threads: int = operation.tile_description.num_threads + + if (operation.epilogue_functor in + [ + EpilogueFunctor.LinearCombination, + EpilogueFunctor.FastLinearCombinationClamp, + EpilogueFunctor.LinearCombinationClamp + ]): + self.output_op = LinearCombinationFunctor() + else: + raise ValueError("unknown epilogue functor") + + # + def emit(self): + return self.emitter.emit(self.operation) + + # + def can_implement(self, configuration, arguments): + raise NotImplementedError() + + # + def get_host_workspace_size(self, arguments): + raise NotImplementedError() + + # + def get_device_workspace_size(self, arguments): + return 0 + + # + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + +################################################################################ +# Runtime module for GEMM Universal +################################################################################ + + +class GemmRTUniversal(GemmRTbase): + """ + GemmRTUniversal manages the CUTLASS runtime components + """ + HostTemplate = r''' +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, \ + cutlass::gemm::GemmCoord* grid_tiled_shape, int gemm_k_size, int* workspace){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, *grid_tiled_shape, gemm_k_size, workspace); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } +} + ''' + + def __init__(self, operation: 'GemmOperation'): + super(GemmRTUniversal, self).__init__(operation) + self.emitter = EmitGemmUniversalInstance( + '_type', operation.direct_store) + self.argtype = [ + ctypes.POINTER(get_gemm_arguments(operation.element_epilogue)[0]), + ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p + ] + + def plan(self, arguments): + + grid = self.threadblock_swizzle.get_tiled_shape( + arguments.problem_size, self.threadblock_shape, arguments.batch_count + ) + + gemm_k_size = arguments.problem_size.k() + if (arguments.gemm_mode in + [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel]): + # + alignk = max(max(128 // DataTypeSize[self.operation.A.element], + 128 // 
DataTypeSize[self.operation.B.element]), 1) + + gemm_k_size = (((arguments.problem_size.k() + arguments.batch_count - 1) // + arguments.batch_count + alignk - 1) // alignk) * alignk + + if gemm_k_size: + grid_z = (arguments.problem_size.k() + + gemm_k_size - 1) // gemm_k_size + grid = cutlass.gemm.GemmCoord(grid.m(), grid.n(), grid_z) + + arguments.grid_tiled_shape = cutlass.dim3(grid.m(), grid.n(), grid.k()) + grid = self.threadblock_swizzle.get_grid_shape(grid) + arguments.gemm_k_size = gemm_k_size + return LaunchConfiguration( + [grid.x, grid.y, grid.z], + [self.threads, 1, 1], + self.shared_memory_capacity) + + # + def get_device_workspace_size(self, arguments: GemmArguments): + workspace_bytes = 0 + if arguments.gemm_mode == cutlass.gemm.Mode.GemmSplitKParallel: + workspace_bytes = (DataTypeSize[arguments.operation.C.element] + * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8) + elif (arguments.gemm_mode == cutlass.gemm.Mode.Gemm and + arguments.split_k_slices > 1): + # + workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y + + # TODO: get extra workspace size + # see https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/gemm/device/gemm_universal_base.h + return workspace_bytes + + +################################################################################################### +# Runtime module for GEMM Grouped +################################################################################################### + + +class GemmRTGrouped(GemmRTbase): + """ + GemmRTGrouped manages the CUTLASS runtime components + """ + HostTemplate = r''' + extern "C" { + + // precompute scheduling information + char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) { + char* host_workspace = new char[workspace_bytes]; + ${operation_name}_base::ProblemVisitor::host_precompute( + args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace + ); + return host_workspace; + } + + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, workspace, tile_count); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } + } + ''' + + def __init__(self, operation: 'GemmOperation'): + super(GemmRTGrouped, self).__init__(operation) + self.extra_funcs = ['precompute'] + + self.emitter = EmitGemmGroupedInstance('_type') + self.argtype = [ctypes.POINTER(get_gemm_grouped_arguments( + operation.element_epilogue)[0]), ctypes.c_int, ctypes.c_void_p] + + def host_precompute(self, arguments, workspace_bytes): + self.precompute.argtype = [ + self.argtype[0], ctypes.c_int, ctypes.c_longlong] + self.precompute.restype = ctypes.POINTER( + ctypes.c_byte * workspace_bytes) + + problem_info = self.precompute(ctypes.byref( + arguments.arguments), arguments.total_tiles, workspace_bytes) + 
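        # The C++ helper returns a host byte array holding the precomputed
        # grouped-GEMM schedule (roughly, which problem and tile offset each
        # threadblock should process); it is copied into a device buffer below
        # so that the device pointer can be handed to get_args() by initialize().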
problem_info_array = bytearray(problem_info.contents) + + # copy to device memory + return rmm.DeviceBuffer.to_device(problem_info_array).ptr + + def plan(self, arguments): + return LaunchConfiguration( + [arguments.total_tiles, 1, 1], + [self.threads, 1, 1], self.shared_memory_capacity) + + def get_workspace_size(self, arguments): + if self.operation.precompute_mode == SchedulerMode.Device: + return 0 + elif self.operation.precompute_mode == SchedulerMode.Host: + total_tiles = arguments.total_tiles + entries_per_block = 1 + return 8 * entries_per_block * total_tiles # three int32_t + + +################################################################################ +# Runtime module for GEMM Grouped +################################################################################ + +# +class GemmOperationBase: + """ + CUTLASS GEMM operation + """ + # + + def __init__( + self, gemm_kind, arch, tile_description: TileDescription, + A: TensorDescription, B: TensorDescription, C: TensorDescription, + element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + + #: operation kind + self.operation_kind: OperationKind = OperationKind.Gemm + #: compute capability + self.arch: int = arch + #: tile description object + self.tile_description: TileDescription = tile_description + #: gemm kind + self.gemm_kind: GemmKind = gemm_kind + + # use deep copy to avoid overwritting the original TensorDescription + if C.layout == cutlass.ColumnMajor: + #: Operand A + self.A: TensorDescription = copy.deepcopy(B) + #: Operand B + self.B: TensorDescription = copy.deepcopy(A) + #: Operand C + self.C: TensorDescription = copy.deepcopy(C) + self.A.layout = transpose_layout(self.A.layout) + self.B.layout = transpose_layout(self.B.layout) + self.C.layout = transpose_layout(self.C.layout) + self.switched = True + else: + #: Operand A + self.A: TensorDescription = copy.deepcopy(A) + #: Operand B + self.B: TensorDescription = copy.deepcopy(B) + #: Operand C + self.C: TensorDescription = copy.deepcopy(C) + self.switched = False + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor() + + if "direct_store" in kwargs: + self.direct_store = kwargs["direct_store"] + else: + self.direct_store = False + + def run(self, arguments: GemmArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('CUDA Error %s' % str(err)) + + return err + + def free(self): + if hasattr(self, "workspace_buffer"): + del self.workspace_buffer + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % 
ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + # + + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + } + + if self.tile_description.math_instruction.opcode_class == cutlass.OpClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == cutlass.OpClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys( + ) else '' + + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[( + self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[( + self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.A.alignment, self.B.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + return self.procedural_name() + + +class GemmOperationUniversal(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, + A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs) + self.rt_module = GemmRTUniversal(self) + + +class GemmOperationGrouped(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, + A, B, C, element_epilogue, epilogue_functor, swizzling_functor, **kwargs) + assert "precompute_mode" in kwargs.keys( + ), "missing keyword arguement 'precompute_mode'." + self.precompute_mode = kwargs["precompute_mode"] + self.rt_module = GemmRTGrouped(self) + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# + + +class EmitGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // + operation.tile_description.warp_count[idx] for idx in range(3)] + + 
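        # Worked example with illustrative values (not taken from this patch):
        # a 128x128x32 threadblock tile with warp_count [2, 2, 1] gives a
        # 64x64x32 warp tile, and a float32 C operand with alignment 4 gives
        # epilogue_vector_length = min(4 * 32, 128) / 32 = 4 elements below.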
epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': operation.swizzling_functor.tag(), + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + +class EmitSparseGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // + 
operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': operation.swizzling_functor.tag(), + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + +# +class EmitGemmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix='', direct_store=False): + self.operation_suffix = operation_suffix + self.direct_store = direct_store + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + if self.direct_store: + self.includes.append( + "cutlass/epilogue/threadblock/default_epilogue_direct_store.h") + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.builtin_epilogue_functor_template_clamp = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + 
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_interleaved = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_direct_store = """ +// Gemm operator ${operation_name} +using ${operation_name}_default = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +using ${operation_name}_base = + cutlass::gemm::kernel::GemmUniversal< + ${operation_name}_default::Mma, + cutlass::epilogue::threadblock::DefaultEpilogueDirectStore< + ${operation_name}_default::Epilogue + >::Epilogue, + ${operation_name}_default::ThreadblockSwizzle + >; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] + for idx in range(3)] + + # transpose_layouts = { + # cutlass.layout.ColumnMajorcutlass.layout.ColumnMajor, + # cutlass.layout.RowMajorcutlass.layout.RowMajor + # } + + # if operation.A.layout in transpose_layouts.keys() and \ + # operation.B.layout in transpose_layouts.keys() and \ + # operation.C.layout in transpose_layouts.keys(): + + # instance_layout_A = transpose_layouts[operation.A.layout] + # instance_layout_B = transpose_layouts[operation.B.layout] + # instance_layout_C = transpose_layouts[operation.C.layout] + + # gemm_template = self.gemm_template + # else: + instance_layout_A, instance_layout_B, instance_layout_C = \ + 
(operation.A.layout, operation.B.layout, operation.C.layout) + if self.direct_store: + gemm_template = self.gemm_template_direct_store + else: + gemm_template = self.gemm_template_interleaved + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * + DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + if operation.epilogue_functor == EpilogueFunctor.FastLinearCombinationClamp: + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template_clamp, values) + else: + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': operation.swizzling_functor.tag(), + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(gemm_template, values) + +################################################################################################### + +# + + +class EmitGemmPlanarComplexInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, 
${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmKernel; + + struct ${operation_name} : + public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // + operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# + + +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + 
${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmArrayKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // + operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# + + +class EmitGemmGroupedInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + 
"cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h" + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmGrouped< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${precompute_mode}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmGrouped<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] + for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * + DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': 
str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': operation.swizzling_functor.tag(), + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'precompute_mode': SchedulerModeTag[operation.precompute_mode], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(self.gemm_template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/library.py b/tools/library/scripts/pycutlass/src/pycutlass/library.py new file mode 100644 index 00000000..3ba16752 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/library.py @@ -0,0 +1,790 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import re + +################################################################################################### + +import enum +import cutlass + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. 
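+# Either way, enum members throughout this file are declared identically, e.g.
+#
+#     class GeneratorTarget(enum.Enum):
+#         Library = enum_auto()
+#
+# the fallback simply hands out increasing plain integers instead of
+# enum.auto() sentinels.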
+# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + +################################################################################################### + +# + + +class GeneratorTarget(enum.Enum): + Library = enum_auto() + +# +GeneratorTargetNames = { + GeneratorTarget.Library: 'library', +} +# + +################################################################################################### + +# +ShortDataTypeNames = { + cutlass.int32: 'i', + cutlass.float16: 'h', + cutlass.float32: 's', + cutlass.float64: 'd', + cutlass.dtype.cf32: 'c', + cutlass.dtype.cf64: 'z', +} + +# +DataTypeNames = { + cutlass.dtype.b1: "b1", + cutlass.dtype.u4: "u4", + cutlass.dtype.u8: "u8", + cutlass.dtype.u16: "u16", + cutlass.dtype.u32: "u32", + cutlass.dtype.u64: "u64", + cutlass.dtype.s4: "s4", + cutlass.int8: "s8", + cutlass.dtype.s16: "s16", + cutlass.int32: "s32", + cutlass.dtype.s64: "s64", + cutlass.float16: "f16", + cutlass.bfloat16: "bf16", + cutlass.float32: "f32", + cutlass.tfloat32: "tf32", + cutlass.float64: "f64", + cutlass.dtype.cf16: "cf16", + cutlass.dtype.cbf16: "cbf16", + cutlass.dtype.cf32: "cf32", + cutlass.dtype.ctf32: "ctf32", + cutlass.dtype.cf64: "cf64", + cutlass.dtype.cu4: "cu4", + cutlass.dtype.cu8: "cu8", + cutlass.dtype.cu16: "cu16", + cutlass.dtype.cu32: "cu32", + cutlass.dtype.cu64: "cu64", + cutlass.dtype.cs4: "cs4", + cutlass.dtype.cs8: "cs8", + cutlass.dtype.cs16: "cs16", + cutlass.dtype.cs32: "cs32", + cutlass.dtype.cs64: "cs64", +} + +DataTypeTag = { + cutlass.dtype.b1: "cutlass::uint1b_t", + cutlass.dtype.u2: "cutlass::uint2b_t", + cutlass.dtype.u4: "cutlass::uint4b_t", + cutlass.dtype.u8: "uint8_t", + cutlass.dtype.u16: "uint16_t", + cutlass.dtype.u32: "uint32_t", + cutlass.dtype.u64: "uint64_t", + cutlass.dtype.s2: "cutlass::int2b_t", + cutlass.dtype.s4: "cutlass::int4b_t", + cutlass.int8: "int8_t", + cutlass.dtype.s16: "int16_t", + cutlass.int32: "int32_t", + cutlass.dtype.s64: "int64_t", + cutlass.float16: "cutlass::half_t", + cutlass.bfloat16: "cutlass::bfloat16_t", + cutlass.float32: "float", + cutlass.tfloat32: "cutlass::tfloat32_t", + cutlass.float64: "double", + cutlass.dtype.cf16: "cutlass::complex", + cutlass.dtype.cbf16: "cutlass::complex", + cutlass.dtype.cf32: "cutlass::complex", + cutlass.dtype.ctf32: "cutlass::complex", + cutlass.dtype.cf64: "cutlass::complex", + cutlass.dtype.cu2: "cutlass::complex", + cutlass.dtype.cu4: "cutlass::complex", + cutlass.dtype.cu8: "cutlass::complex", + cutlass.dtype.cu16: "cutlass::complex", + cutlass.dtype.cu32: "cutlass::complex", + cutlass.dtype.cu64: "cutlass::complex", + cutlass.dtype.cs2: "cutlass::complex", + cutlass.dtype.cs4: "cutlass::complex", + cutlass.dtype.cs8: "cutlass::complex", + cutlass.dtype.cs16: "cutlass::complex", + cutlass.dtype.cs32: "cutlass::complex", + cutlass.dtype.cs64: "cutlass::complex", +} + +DataTypeSize = { + cutlass.dtype.b1: 1, + cutlass.dtype.u4: 4, + cutlass.dtype.u8: 8, + cutlass.dtype.u16: 16, + cutlass.dtype.u32: 32, + cutlass.dtype.u64: 64, + cutlass.dtype.s4: 4, + cutlass.int8: 8, + cutlass.dtype.s16: 16, + cutlass.int32: 32, + cutlass.dtype.s64: 64, + cutlass.float16: 16, + cutlass.bfloat16: 16, + cutlass.float32: 32, + cutlass.tfloat32: 32, + cutlass.float64: 64, + 
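+    # sizes of the complex data types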
cutlass.dtype.cf16: 32, + cutlass.dtype.cbf16: 32, + cutlass.dtype.cf32: 64, + cutlass.dtype.ctf32: 32, + cutlass.dtype.cf64: 128, + cutlass.dtype.cu4: 8, + cutlass.dtype.cu8: 16, + cutlass.dtype.cu16: 32, + cutlass.dtype.cu32: 64, + cutlass.dtype.cu64: 128, + cutlass.dtype.cs4: 8, + cutlass.dtype.cs8: 16, + cutlass.dtype.cs16: 32, + cutlass.dtype.cs32: 64, + cutlass.dtype.cs64: 128, +} + +################################################################################################### +# + + +class BlasMode(enum.Enum): + symmetric = enum_auto() + hermitian = enum_auto() + + +# +BlasModeTag = { + BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', + BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', +} + +# +ComplexTransformTag = { + cutlass.complex_transform.none: 'cutlass::ComplexTransform::kNone', + cutlass.complex_transform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# +RealComplexBijection = [ + (cutlass.float16, cutlass.dtype.cf16), + (cutlass.float32, cutlass.dtype.cf32), + (cutlass.float64, cutlass.dtype.cf64), +] + +# + + +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +# + + +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return cutlass.dtype.invalid + +# + + +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return cutlass.dtype.invalid + +# + + +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + +################################################################################################### + +# + + +class MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + xor_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + + +# +MathOperationNames = { + MathOperation.multiply_add: 'multiply_add', + MathOperation.multiply_add_saturate: 'multiply_add_saturate', + MathOperation.xor_popc: 'xor_popc', + MathOperation.multiply_add_fast_bf16: 'multiply_add_fast_bf16', + MathOperation.multiply_add_fast_f16: 'multiply_add_fast_f16', + MathOperation.multiply_add_fast_f32: 'multiply_add_fast_f32', + MathOperation.multiply_add_complex_fast_f32: 'multiply_add_complex_fast_f32', + MathOperation.multiply_add_complex: 'multiply_add_complex', + MathOperation.multiply_add_complex_gaussian: 'multiply_add_complex_gaussian', +} + +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', + MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', + MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', + MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', +} + +################################################################################################### + +# +LayoutTag = { + cutlass.ColumnMajor: 
'cutlass::layout::ColumnMajor', + cutlass.RowMajor: 'cutlass::layout::RowMajor', + cutlass.layout.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', + cutlass.layout.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', + cutlass.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', + cutlass.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', + cutlass.layout.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', + cutlass.layout.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', + cutlass.TensorNHWC: 'cutlass::layout::TensorNHWC', + cutlass.layout.TensorNDHWC: 'cutlass::layout::TensorNDHWC', + cutlass.layout.TensorNCHW: 'cutlass::layout::TensorNCHW', + cutlass.layout.TensorNGHWC: 'cutlass::layout::TensorNGHWC', + cutlass.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', + cutlass.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', + cutlass.layout.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', + cutlass.layout.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', +} + +# +TransposedLayout = { + cutlass.ColumnMajor: cutlass.RowMajor, + cutlass.RowMajor: cutlass.ColumnMajor, + cutlass.layout.ColumnMajorInterleaved2: cutlass.layout.RowMajorInterleaved2, + cutlass.layout.RowMajorInterleaved2: cutlass.layout.ColumnMajorInterleaved2, + cutlass.ColumnMajorInterleaved32: cutlass.RowMajorInterleaved32, + cutlass.RowMajorInterleaved32: cutlass.ColumnMajorInterleaved32, + cutlass.layout.ColumnMajorInterleaved64: cutlass.layout.RowMajorInterleaved64, + cutlass.layout.RowMajorInterleaved64: cutlass.layout.ColumnMajorInterleaved64, + cutlass.TensorNHWC: cutlass.TensorNHWC +} + +# +ShortLayoutTypeNames = { + cutlass.ColumnMajor: 'n', + cutlass.layout.ColumnMajorInterleaved2: 'n2', + cutlass.ColumnMajorInterleaved32: 'n32', + cutlass.layout.ColumnMajorInterleaved64: 'n64', + cutlass.RowMajor: 't', + cutlass.layout.RowMajorInterleaved2: 't2', + cutlass.RowMajorInterleaved32: 't32', + cutlass.layout.RowMajorInterleaved64: 't64', + cutlass.TensorNHWC: 'nhwc', + cutlass.layout.TensorNDHWC: 'ndhwc', + cutlass.layout.TensorNCHW: 'nchw', + cutlass.layout.TensorNGHWC: 'nghwc', + cutlass.TensorNC32HW32: 'nc32hw32', + cutlass.layout.TensorNC64HW64: 'nc64hw64', + cutlass.TensorC32RSK32: 'c32rsk32', + cutlass.layout.TensorC64RSK64: 'c64rsk64' +} + +# +ShortComplexLayoutNames = { + (cutlass.ColumnMajor, cutlass.complex_transform.none): 'n', + (cutlass.ColumnMajor, cutlass.complex_transform.conj): 'c', + (cutlass.RowMajor, cutlass.complex_transform.none): 't', + (cutlass.RowMajor, cutlass.complex_transform.conj): 'h' +} + +################################################################################################### + +# + + +class SideMode(enum.Enum): + Left = enum_auto() + Right = enum_auto() + + +# +SideModeTag = { + SideMode.Left: 'cutlass::SideMode::kLeft', + SideMode.Right: 'cutlass::SideMode::kRight' +} + +# +ShortSideModeNames = { + SideMode.Left: 'ls', + SideMode.Right: 'rs' +} + +################################################################################################### + +# + + +class FillMode(enum.Enum): + Lower = enum_auto() + Upper = enum_auto() + + +# +FillModeTag = { + FillMode.Lower: 'cutlass::FillMode::kLower', + FillMode.Upper: 'cutlass::FillMode::kUpper' +} + +# +ShortFillModeNames = { + FillMode.Lower: 'l', + FillMode.Upper: 'u' +} + +################################################################################################### + +# + + +class DiagType(enum.Enum): + 
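+    # NonUnit: diagonal elements are read from the matrix; Unit: an implicit
+    # all-ones diagonal is assumed (the usual BLAS TRMM/TRSM convention)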
NonUnit = enum_auto() + Unit = enum_auto() + + +# +DiagTypeTag = { + DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', + DiagType.Unit: 'cutlass::DiagType::kUnit' +} + +# +ShortDiagTypeNames = { + DiagType.NonUnit: 'nu', + DiagType.Unit: 'un' +} + +################################################################################################### + +OpcodeClassNames = { + cutlass.OpClass.Simt: 'simt', + cutlass.OpClass.TensorOp: 'tensorop', + cutlass.OpClass.WmmaTensorOp: 'wmma_tensorop', + cutlass.OpClass.SparseTensorOp: 'sptensorop' +} + +OpcodeClassTag = { + cutlass.OpClass.Simt: 'cutlass::arch::OpClassSimt', + cutlass.OpClass.TensorOp: 'cutlass::arch::OpClassTensorOp', + cutlass.OpClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', + cutlass.OpClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp' +} + +################################################################################################### + +# + + +class OperationKind(enum.Enum): + Gemm = enum_auto() + RankK = enum_auto() + Rank2K = enum_auto() + Trmm = enum_auto() + Symm = enum_auto() + Conv2d = enum_auto() + Conv3d = enum_auto() + + +# +OperationKindNames = { + OperationKind.Gemm: 'gemm', OperationKind.RankK: 'rank_k', OperationKind.Rank2K: 'rank_2k', OperationKind.Trmm: 'trmm', OperationKind.Symm: 'symm', OperationKind.Conv2d: 'conv2d', OperationKind.Conv3d: 'conv3d' +} + +# +ArchitectureNames = { + 50: 'maxwell', + 60: 'pascal', + 61: 'pascal', + 70: 'volta', + 75: 'turing', + 80: 'ampere', +} + +# +SharedMemPerCC = { + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM + 80: 160, # 164KB of SMEM - 4KB reserved for the driver + 86: 100, # 100KB of SMEM + 87: 160, # 164KB of SMEM - 4KB reserved for the driver +} + +################################################################################################### + +# + + +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + +################################################################################################### + +# + + +class GemmKind(enum.Enum): + Gemm = enum_auto() + Sparse = enum_auto() + Universal = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() + Grouped = enum_auto() + + +# +GemmKindNames = { + GemmKind.Gemm: "gemm", + GemmKind.Sparse: "spgemm", + GemmKind.Universal: "gemm", + GemmKind.PlanarComplex: "gemm_planar_complex", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", + GemmKind.Grouped: "gemm_grouped" +} + +# + + +class RankKKind(enum.Enum): + Universal = enum_auto() + + +# +RankKKindNames = { + RankKKind.Universal: "rank_k" +} + +# + + +class TrmmKind(enum.Enum): + Universal = enum_auto() + + +# +TrmmKindNames = { + TrmmKind.Universal: "trmm" +} + +# + + +class SymmKind(enum.Enum): + Universal = enum_auto() + + +# +SymmKindNames = { + SymmKind.Universal: "symm" +} + +# + + +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationClamp = enum_auto() + FastLinearCombinationClamp = enum_auto() + + +# +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', + EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', + EpilogueFunctor.FastLinearCombinationClamp: 'cutlass::epilogue::thread::FastLinearCombinationClamp' +} + +# + + 
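+# SubstituteTemplate (defined above) expands ${...} placeholders and keeps
+# substituting until the text stops changing, so values that themselves
+# contain ${...} are resolved as well. A minimal illustration with
+# hypothetical keys:
+#
+#   SubstituteTemplate(
+#       "using Epilogue = ${functor}<${element}, ${alignment}>;",
+#       {'functor': EpilogueFunctorTag[EpilogueFunctor.LinearCombination],
+#        'element': 'cutlass::half_t',
+#        'alignment': '8'})
+#   # -> "using Epilogue = cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8>;"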
+class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + Horizontal = enum_auto() + BatchedIdentity1 = enum_auto() + StridedDgradIdentity1 = enum_auto() + StridedDgradIdentity4 = enum_auto() + StridedDgradHorizontal = enum_auto() + + +# +SwizzlingFunctorTag = { + cutlass.IdentitySwizzle1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', + SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', + SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', + SwizzlingFunctor.BatchedIdentity1: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle", + SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', +} + +# + + +class SchedulerMode(enum.Enum): + Device = enum_auto(), + Host = enum_auto() + + +# +SchedulerModeTag = { + SchedulerMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', + SchedulerMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' +} + +# +ShortSchedulerModeNames = { + SchedulerMode.Device: 'Device', + SchedulerMode.Host: 'Host' +} + +################################################################################################### + + +# +ConvKindTag = { + cutlass.conv.Operator.fprop: 'cutlass::conv::Operator::kFprop', + cutlass.conv.Operator.dgrad: 'cutlass::conv::Operator::kDgrad', + cutlass.conv.Operator.wgrad: 'cutlass::conv::Operator::kWgrad' +} + +ConvKindNames = { + cutlass.conv.Operator.fprop: 'fprop', + cutlass.conv.Operator.dgrad: 'dgrad', + cutlass.conv.Operator.wgrad: 'wgrad', +} + + +# +IteratorAlgorithmTag = { + cutlass.conv.IteratorAlgorithm.analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', + cutlass.conv.IteratorAlgorithm.optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', + cutlass.conv.IteratorAlgorithm.fixed_channels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', + cutlass.conv.IteratorAlgorithm.few_channels: 'cutlass::conv::IteratorAlgorithm::kFewChannels' +} + +IteratorAlgorithmNames = { + cutlass.conv.IteratorAlgorithm.analytic: 'analytic', + cutlass.conv.IteratorAlgorithm.optimized: 'optimized', + cutlass.conv.IteratorAlgorithm.fixed_channels: 'fixed_channels', + cutlass.conv.IteratorAlgorithm.few_channels: 'few_channels' +} + +# + + +class StrideSupport(enum.Enum): + Strided = enum_auto() + Unity = enum_auto() + + +# +StrideSupportTag = { + StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', + StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', +} + +StrideSupportNames = { + StrideSupport.Strided: '', + StrideSupport.Unity: 'unity_stride', +} + + +class ConvMode(enum.Enum): + CrossCorrelation = enum_auto() + Convolution = enum_auto() + + +# +ConvModeTag = { + ConvMode.CrossCorrelation: 'cutlass::conv::Mode::kCrossCorrelation', + ConvMode.Convolution: 'cutlass::conv::Mode::kConvolution' +} + +################################################################################################### + +# + + +class MathInstruction: + 
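+    # Describes a single MMA/FMA instruction: its M/N/K shape, operand and
+    # accumulator data types, opcode class (e.g. SIMT or Tensor Core), and the
+    # math operation (multiply_add by default).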
def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class=cutlass.OpClass.Simt, math_operation=MathOperation.multiply_add): + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + +# + + +class TileDescription: + + def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = threadblock_shape + + #: number of pipeline stages + self.stages: int = stages + + #: number of warps along x, y, z directions + self.warp_count: list[int] = warp_count + self.math_instruction = math_instruction + + #: minimum compute capability + self.minimum_compute_capability: int = min_compute + #: maximum compute capability + self.maximum_compute_capability: int = max_compute + + #: number threads per threadblock + self.num_threads: int = 32 + for cnt in self.warp_count: + self.num_threads *= cnt + + def procedural_name(self): + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + +# + + +class TensorDescription: + def __init__(self, element, layout, alignment=1, complex_transform=cutlass.complex_transform.none): + self.element = element + self.layout = layout + self.alignment = min(128 // DataTypeSize[self.element], alignment) + self.complex_transform = complex_transform + +# + + +class SymmetricTensorDescription: + def __init__(self, element, layout, fill_mode, alignment=1, complex_transform=cutlass.complex_transform.none, side_mode=SideMode.Left): + self.element = element + self.layout = layout + self.fill_mode = fill_mode + self.alignment = alignment + self.complex_transform = complex_transform + self.side_mode = side_mode + +# + + +class TriangularTensorDescription: + def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment=1, complex_transform=cutlass.complex_transform.none): + self.element = element + self.layout = layout + self.side_mode = side_mode + self.fill_mode = fill_mode + self.diag_type = diag_type + self.alignment = alignment + self.complex_transform = complex_transform + +################################################################################################### + +# + + +def CalculateSmemUsage(operation): + cta_shape = operation.tile_description.threadblock_shape + stages = operation.tile_description.stages + + if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: + # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) + if DataTypeSize[operation.A.element] == 32: + elements_per_8b_md = 2 + elif DataTypeSize[operation.A.element] == 4: + elements_per_8b_md = 8 + else: + elements_per_8b_md = 4 + + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md + else: + # Few BLAS3 operations only have A tensor + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \ + DataTypeSize[operation.A.element] * \ + cta_shape[1] * cta_shape[2] // 8 + + smem_usage = smem_per_stage * stages + return (smem_usage >> 10) +################################################################################################### diff --git 
a/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py b/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py new file mode 100644 index 00000000..89c97bad --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import rmm +import numpy as np + + +class PoolMemoryManager: + def __init__(self, init_pool_size: int, max_pool_size: int) -> None: + self.pool = rmm.mr.PoolMemoryResource( + rmm.mr.CudaMemoryResource(), + initial_pool_size=init_pool_size, + maximum_pool_size=max_pool_size + ) + self.mr = rmm.mr.TrackingResourceAdaptor(self.pool) + rmm.mr.set_current_device_resource(self.mr) + + def get_allocated_size(self): + return self.mr.get_allocated_bytes() + + def pool_size(self): + return self.pool.pool_size() + + +def todevice(host_data, dtype=np.float32): + """ + Pass the host_data to device memory + """ + if isinstance(host_data, list): + return rmm.DeviceBuffer.to_device(np.array(host_data, dtype=dtype).tobytes()) + elif isinstance(host_data, np.ndarray): + return rmm.DeviceBuffer.to_device(host_data.tobytes()) + + +def device_mem_alloc(size): + return rmm.DeviceBuffer(size=size) + + +def align_size(size, alignment=256): + return ((size + alignment - 1) // alignment) * alignment + + +def get_allocated_size(): + device_resource = rmm.mr.get_current_device_resource() + return device_resource.get_allocated_bytes() diff --git a/tools/library/scripts/pycutlass/src/pycutlass/operation.py b/tools/library/scripts/pycutlass/src/pycutlass/operation.py new file mode 100644 index 00000000..b1721b03 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/operation.py @@ -0,0 +1,110 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +import ctypes +from cuda import cuda + +################################################################################ +# +# Launch configuration +# +################################################################################ + + +class LaunchConfiguration: + def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0): + self.grid = grid + self.block = block + self.shared_memory_capacity = smem + + +################################################################################ +# +# Base class for an executable operation +# +# ############################################################################## + +class ExecutableOperation: + ''' + ''' + + def __init__(self, operation): + self.operation = operation + self.module = None + self.kernel = None + + # + def name(self): + return self.operation.procedural_name() + + # + def emit(self): + return '' + + # + def can_implement(self, configuration, arguments): + raise NotImplementedError() + + # + def get_host_workspace_size(self, arguments): + raise NotImplementedError() + + # + def get_device_workspace_size(self, arguments): + raise NotImplementedError() + + # + def plan(self, arguments): + raise NotImplementedError() + + # + def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=cuda.CUstream(0)): + raise NotImplementedError() + + # + def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstream(0)): + + cArg = (ctypes.c_char * len(host_workspace) + ).from_buffer(host_workspace) + packed = (ctypes.c_void_p * 1)() + packed[0] = ctypes.addressof(cArg) + + err, = cuda.cuLaunchKernel( + self.kernel, + launch_config.grid[0], launch_config.grid[1], launch_config.grid[2], + launch_config.block[0], launch_config.block[1], launch_config.block[2], + launch_config.shared_memory_capacity, + stream, + packed, + 0) + + return err diff --git a/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py b/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py new file mode 100644 index 00000000..a5f7217a --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py @@ -0,0 +1,402 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +from pycutlass import * +from pycutlass.c_types import get_reduction_params +import cutlass +from cuda import cuda +try: + import torch + torch_available = True +except ImportError: + torch_available = False +import numpy as np +from typing import Union +from cuda import cudart + + +class ReductionOperation: + pass + + +class ReductionArguments: + """ + Arguments of reduction + """ + + def __init__(self, operation: ReductionOperation, + problem_size: 'list[int]', partitions: int, + workspace: cuda.CUdeviceptr, + destination: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', + source: 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]', **kwargs) -> None: + + self.operation = operation + #: pointer to the workspace + self.ptr_workspace = workspace + + #: number of split-k partitions + self.partitions = partitions + + if isinstance(destination, np.ndarray): + self.host_D = destination + self.destination_buffer = NumpyFrontend.argument(destination, True) + self.source_buffer = NumpyFrontend.argument(source, False) + self.ptr_destination = cuda.CUdeviceptr( + self.destination_buffer.ptr) + self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) + elif torch_available and isinstance(destination, torch.Tensor): + self.ptr_destination = TorchFrontend.argument(destination) + self.ptr_source = TorchFrontend.argument(source) + elif isinstance(destination, cuda.CUdeviceptr): + self.ptr_destination = destination + self.ptr_source = source + else: + raise TypeError("unknown Type") + + self.problem_size = MatrixCoord_( + problem_size[0], problem_size[1] + ) + + self.partition_stride = problem_size[0] * \ + problem_size[1] * DataTypeSize[operation.C.element] // 8 + + if "output_op" in kwargs.keys(): + self.alpha = kwargs["output_op"].alpha + self.beta = kwargs["output_op"].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + # get arguments + self.get_arguments() + + @staticmethod + def get_tensor_ref(extent: 'tuple[int]', device_ptr: cuda.CUdeviceptr, layout: cutlass.layout): + if layout == cutlass.RowMajor: + return TensorRef2D_(int(device_ptr), extent[1]) + else: + raise ValueError("unknonwn layout type") + + def get_arguments(self): + ref_workspace = ReductionArguments.get_tensor_ref( + extent=[self.problem_size.row, self.problem_size.column], + device_ptr=self.ptr_workspace, layout=cutlass.RowMajor) + + ref_source = ReductionArguments.get_tensor_ref( + extent=[self.problem_size.row, self.problem_size.column], + device_ptr=self.ptr_source, layout=cutlass.RowMajor) + + ref_destination = ReductionArguments.get_tensor_ref( + extent=[self.problem_size.row, self.problem_size.column], + device_ptr=self.ptr_destination, layout=cutlass.RowMajor) + + argument_type, epilogue_type = get_reduction_params( + self.operation.element_compute) + + if self.operation.element_compute == cutlass.float16: + self.alpha = cutlass.float16(self.alpha).storage + self.beta = cutlass.float16(self.beta).storage + 
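+        # integer epilogues take plain Python ints for alpha/beta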
elif self.operation.element_compute == cutlass.int32: + self.alpha = int(self.alpha) + self.beta = int(self.beta) + + output_op = epilogue_type(self.alpha, self.beta, 0, 0) + self.c_arguments = argument_type( + self.problem_size, self.partitions, + self.partition_stride, ref_workspace, + ref_destination, ref_source, + output_op + ) + + params_ = self.operation.rt_module.get_args( + ctypes.byref(self.c_arguments)) + self.host_workspace = bytearray(params_.contents) + + def sync(self): + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + if hasattr(self, "host_D"): + err, = cuda.cuMemcpyDtoH( + self.host_D, self.ptr_destination, self.host_D.size * self.host_D.itemsize) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + def free(self): + if hasattr(self, "destination_buffer"): + del self.destination_buffer + if hasattr(self, "source_buffer"): + del self.source_buffer + + +class ReductionRT(ExecutableOperation): + """ + ReductionRT manages the CUTLASS runtime components for reduction + """ + KernelTemplate = r''' +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + ''' + HostTemplate = r''' +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){ + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } +} + ''' + + def __init__(self, operation: ReductionOperation): + super().__init__(operation) + + self.operation: ReductionOperation = operation + self.emitter = EmitReductionInstance('_type') + + self.elements_per_access = self.operation.count + self.argtype = [ctypes.POINTER( + get_reduction_params(operation.element_compute)[0])] + + def emit(self): + return self.emitter.emit(self.operation) + + def plan(self, arguments: ReductionArguments): + block_shape = [self.operation.shape.column( + ) // self.elements_per_access, self.operation.shape.row(), 1] + grid_shape = [ + (arguments.problem_size.row + self.operation.shape.row() - + 1) // self.operation.shape.row(), + (arguments.problem_size.column + self.operation.shape.column() - + 1) // self.operation.shape.column(), + 1 + ] + return LaunchConfiguration(grid_shape, block_shape, self.shared_memory_capacity) + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Cuda Error: {}'.format(err)) + + +class 
ReductionOperation: + """ + CUTLASS Reduction Operation + shape: shape of CTA + outputop: output operator + r + """ + + def __init__(self, shape: cutlass.MatrixCoord, C: TensorDescription, + element_accumulator, element_workspace=None, + element_compute=None, epilogue_functor: EpilogueFunctor = EpilogueFunctor.LinearCombination, + count: int = 1, partitions_per_stage: int = 4) -> None: + """ Constructor + """ + + self.shape = shape + #: epilogue functor (default: LinearCombination) + self.epilogue_functor: EpilogueFunctor = epilogue_functor + #: datatype of accumulator + self.element_accumulator = element_accumulator + + if element_workspace is None: + #: datatype of workspace + self.element_workspace = element_accumulator + else: + #: datatype of workspace + self.element_workspace = element_workspace + + if element_compute is None: + #: datatype of workspace + self.element_compute = element_accumulator + else: + #: datatype of workspace + self.element_compute = element_compute + + #: datatype of output + self.element_output = C.element + + #: operand C + self.C: TensorDescription = C + + #: reduce op processing size + self.count: int = count + + #: number of partitions to reduce per stage + self.partitions_per_stage: int = partitions_per_stage + + self.rt_module: ReductionRT = ReductionRT(self) + + # + def extended_name(self): + extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}" + + return SubstituteTemplate(extend_name, + { + 'element_workspace': DataTypeNames[self.element_workspace], + 'element_accumulator': DataTypeNames[self.element_accumulator], + 'element_compute': DataTypeNames[self.element_compute], + 'element_output': DataTypeNames[self.element_output] + }) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size''' + + configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}" + + threadblock = "%dx%d" % ( + self.shape.row(), + self.shape.column() + ) + + return SubstituteTemplate( + configuration_name, + { + 'extended_name': self.extended_name(), + 'threadblock': threadblock + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architeture, extended name, tile size''' + return self.configuration_name() + + def run(self, arguments: ReductionArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + # get launch configuration + launch_config = self.rt_module.plan(arguments) + + # get the host and device workspace + host_workspace = arguments.host_workspace + device_workspace = None + + # launch the kernel + err = self.rt_module.run( + host_workspace, device_workspace, launch_config) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('CUDA Error %s' % str(err)) + + return err + + +class EmitReductionInstance: + def __init__(self, operation_suffix='') -> None: + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + "cutlass/reduction/kernel/reduce_split_k.h", + "cutlass/reduction/thread/reduction_operators.h" + ] + self.template = """ +// Reduction kernel instance +using ${operation_name}_base = +typename cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<${shape_row}, ${shape_column}>, + 
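+    // output tile shape processed per reduction threadblock (CTA)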
${epilogue_functor}< + ${element_output}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_compute} + >, + cutlass::reduction::thread::ReduceAdd< + ${element_accumulator}, + ${element_output}, + ${count}>, + ${partition_per_stage}>; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + """ + + def emit(self, operation: ReductionOperation): + + epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.configuration_name(), + 'operation_suffix': self.operation_suffix, + 'shape_row': str(operation.shape.row()), + 'shape_column': str(operation.shape.column()), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_output': DataTypeTag[operation.element_output], + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_accumulator': DataTypeTag[operation.element_accumulator], + 'element_compute': DataTypeTag[operation.element_compute], + 'element_workspace': DataTypeTag[operation.element_workspace], + 'count': str(operation.count), + 'partition_per_stage': str(operation.partitions_per_stage) + } + + return SubstituteTemplate(self.template, values) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py b/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py new file mode 100644 index 00000000..4d2b89e6 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py @@ -0,0 +1,71 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +from typeguard import typechecked +import numpy as np +try: + import torch + torch_available = True +except ImportError: + torch_available = False +from cuda import cuda +try: + import cupy as cp + cupy_available = True +except ImportError: + cupy_available = False +import cutlass + + +# @typechecked +class TensorRef: + """ + Python Wrapper for cutlass.TensorRef + """ + def __init__(self, tensor, dtype, layout) -> None: + if isinstance(tensor, np.ndarray): + ptr = cuda.CUdeviceptr(tensor.__array_interface__['data'][0]) + elif torch_available and isinstance(tensor, torch.Tensor): + ptr = cuda.CUdeviceptr(tensor.data_ptr()) + elif cupy_available and isinstance(tensor, cp.ndarray): + ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) + elif isinstance(tensor, cuda.CUdeviceptr): + ptr = tensor + elif isinstance(tensor, int): + ptr = cuda.CUdeviceptr(tensor) + else: + raise NotImplementedError(tensor) + + # the dtype(0) is used to overload between different data types + # with the same layout + self.tensor_ref = cutlass.get_tensor_ref(int(ptr), dtype(0), layout) + diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py new file mode 100644 index 00000000..dacdc43a --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py @@ -0,0 +1,4 @@ +from pycutlass.test.profiler import * +from pycutlass.test.conv2d_testbed import * +from pycutlass.test.gemm_testbed import * +from pycutlass.test.gemm_grouped_testbed import * diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py new file mode 100644 index 00000000..5f9cd3c1 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py @@ -0,0 +1,646 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pycutlass +from pycutlass import * +from pycutlass.test import * +from time import sleep +from bfloat16 import bfloat16 +import subprocess +from typeguard import typechecked +import re + + + +def getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand): + ptr = tensor.__array_interface__['data'][0] + if operand == "a": + tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size) + elif operand == "b": + tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size) + elif operand in ["c", "d"]: + tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size) + else: + raise ValueError("unknown operand: " + operand) + + layout = tensor_layout.packed(tensor_coord) + + if tensor.dtype == np.float64: + return cutlass.TensorRefF64NHWC(ptr, layout) + elif tensor.dtype == np.float32: + return cutlass.TensorRefF32NHWC(ptr, layout) + elif tensor.dtype == np.float16: + return cutlass.TensorRefF16NHWC(ptr, layout) + if tensor.dtype == bfloat16: + return cutlass.TensorRefBF16NHWC(ptr, layout) + elif tensor.dtype == np.int32: + return cutlass.TensorRefS32NHWC(ptr, layout) + elif tensor.dtype == np.int8: + if tensor_layout == cutlass.TensorNC32HW32: + return cutlass.TensorRefS8NC32HW32(ptr, layout) + elif tensor_layout == cutlass.TensorC32RSK32: + return cutlass.TensorRefS8C32RSK32(ptr, layout) + else: + return cutlass.TensorRefS8NHWC(ptr, layout) + else: + raise ValueError("unsupported data type") + +def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand): + tensor_ref = getTensorRef(tensor, tensor_layout, conv_kind, problem_size, operand) + + if operand == "a": + tensor_coord = cutlass.conv.implicit_gemm_tensor_a_extent(conv_kind, problem_size) + elif operand == "b": + tensor_coord = cutlass.conv.implicit_gemm_tensor_b_extent(conv_kind, problem_size) + elif operand in ["c", "d"]: + tensor_coord = cutlass.conv.implicit_gemm_tensor_c_extent(conv_kind, problem_size) + else: + raise ValueError("unknown operand: " + operand) + + if tensor.dtype == np.float64: + return cutlass.TensorViewF64NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.float32: + return cutlass.TensorViewF32NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.float16: + return cutlass.TensorViewF16NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == bfloat16: + return cutlass.TensorViewBF16NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.int32: + return cutlass.TensorViewS32NHWC(tensor_ref, tensor_coord) + elif tensor.dtype == np.int8: + if tensor_layout == cutlass.TensorNC32HW32: + return cutlass.TensorViewS8NC32HW32(tensor_ref, tensor_coord) + elif tensor_layout == cutlass.TensorC32RSK32: + return cutlass.TensorViewS8C32RSK32(tensor_ref, tensor_coord) + else: + return cutlass.TensorViewS8NHWC(tensor_ref, tensor_coord) + + else: + raise ValueError("unsupported data type") + + + +# 
@typechecked +class Conv2dLauncher: + """ + Launcher that runs the operation on given problem size + """ + def __init__(self, operation: 'Conv2dOperation', seed: int=2080, interleaved=False, + verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None: + + self.enable_cached_results = True + self.interleaved = interleaved + + # create the reduction kernel + self.reduction_operation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.element_epilogue, + count=operation.C.alignment + ) + + #: verify the output result + self.verification = verification + #: profile the kernel's runtime + self.profiling = profiling + + self.timer = GpuTimer() + + self.warmup_iterations = warmup_iterations + self.iterations = iterations + + if "sleep" in kwargs.keys(): + self.sleep_time = kwargs["sleep"] + else: + self.sleep_time = 0 + + # + # Compile the operator + # + + pycutlass.compiler.add_module([operation, self.reduction_operation]) + + self.operation = operation + + self.dtype_A = Conv2dLauncher.numpy_type(operation.A.element) + self.layout_A = operation.A.layout + self.dtype_B = Conv2dLauncher.numpy_type(operation.B.element) + self.layout_B = operation.B.layout + self.dtype_C = Conv2dLauncher.numpy_type(operation.C.element) + self.layout_C = operation.C.layout + self.dtype_D = Conv2dLauncher.numpy_type(operation.C.element) + self.layout_D = operation.C.layout + + accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator] + element_size = DataTypeSize[operation.A.element] + + if element_size <= 8: + self.scope = 1 + elif element_size == 16: + if accumulator_size <= 16: + self.scope = 2 + else: + self.scope = 4 + else: + self.scope = 7 + + # Seed + self.seed = seed + + self.conv_kind = operation.conv_kind + + + # + # Get the host reference function + # + + self.element_compute = operation.element_epilogue + + self.host_conv2d = cutlass.test.conv.host.conv2d + + self.timer = GpuTimer() + + @staticmethod + def numpy_type(type): + if type == cutlass.float64: + return np.float64 + elif type == cutlass.float32: + return np.float32 + elif type == cutlass.float16: + return np.float16 + elif type == cutlass.bfloat16: + return bfloat16 + elif type == cutlass.int32: + return np.int32 + elif type == cutlass.int8: + return np.int8 + else: + raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) + + def print_problem_size(self, p, split_k_mode=1): + print("nhwc_%dx%dx%dx%d_krsc_%dx%dx%dx%d_padding_%dx%d_stride_%dx%d_dilation_%dx%d_splitkslices_%d_splitkmode_%d" + % (p.N, p.H, p.W, p.C, p.K, p.R, p.S, p.C, p.pad_h, + p.pad_w, p.stride_h, p.stride_w, p.dilation_h, p.dilation_w, p.split_k_slices, split_k_mode)) + + def uniform_init(self, size, dtype): + if dtype in [np.float32, np.float16, bfloat16, np.float64]: + return np.ceil( + np.random.uniform( + low=-self.scope - 0.5, high=self.scope - 0.5, + size=size).astype(dtype) + ) + else: + return np.random.uniform( + low=-self.scope - 1, high=self.scope + 1, + size=size).astype(dtype) + + def eq_gemm_size(self, problem_size): + n = problem_size.N + p = problem_size.P + q = problem_size.Q + k = problem_size.K + r = problem_size.R + s = problem_size.S + c = problem_size.C + h = problem_size.H + w = problem_size.W + if self.conv_kind == cutlass.conv.Operator.fprop: + return cutlass.gemm.GemmCoord(n * p * q, k, r * s * c) + elif 
self.conv_kind == cutlass.conv.Operator.dgrad: + return cutlass.gemm.GemmCoord(n * h * w, c, k * r * s) + else: + return cutlass.gemm.GemmCoord(k, r * s * c, n * p * q) + + def bytes(self, problem_size, alpha, beta): + mnk = self.eq_gemm_size(problem_size) + + bytes_ = \ + (DataTypeSize[self.operation.A.element] * mnk.m() // 8) * mnk.k() + \ + (DataTypeSize[self.operation.B.element] * mnk.n() // 8) * mnk.k() + \ + (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() + + if beta != 0: + bytes_ += (DataTypeSize[self.operation.C.element] * mnk.m() // 8) * mnk.n() + + return bytes_ + + def flops(self, problem_size): + mnk = self.eq_gemm_size(problem_size) + + flops_mainloop_ = mnk.m() * mnk.n() * mnk.k() * 2 + flops_epilogue_ = mnk.m() * mnk.n() * 2 + + # Adjust mainloop flop for dgrad stride + if self.conv_kind == cutlass.conv.Operator.dgrad: + flops_mainloop_ = flops_mainloop_ // (problem_size.stride_h * problem_size.stride_w) + + flops_total_ = flops_mainloop_ + flops_epilogue_ + + # TODO complex-value support + # switch (operation_desc.tile_description.math_instruction.math_operation) { + # case library::MathOperationID::kMultiplyAddComplex: + # flops_total_ *=4; + # break; + + # default: break; + # } + + return flops_total_ + + + + def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): + if self.element_compute == cutlass.float16: + alpha = cutlass.float16(alpha) + beta = cutlass.float16(beta) + elif self.element_compute == cutlass.int32: + alpha = int(alpha) + beta = int(beta) + else: + alpha = alpha + beta = beta + + # if cached result is loaded + cached_result_loaded = False + + if self.enable_cached_results: + # get problem key + cached_test_key = cutlass.test.conv.host.CreateCachedConv2dTestKey( + self.conv_kind, problem_size, alpha, beta, + getTensorView(tensor_A, self.layout_A, self.conv_kind, problem_size, "a"), + getTensorView(tensor_B, self.layout_B, self.conv_kind, problem_size, "b"), + getTensorView(tensor_C, self.layout_C, self.conv_kind, problem_size, "c"), + ) + + cached_test_result = cutlass.test.conv.host.CachedTestResult() + + conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (self.operation.arch, self.seed) + + cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name) + # CachedTestResultListing cached_results(conv2d_result_cache_name); + cached = cached_results.find(cached_test_key) + cached_result_loaded = cached[0] + if cached_result_loaded : + cached_test_result = cached[1] + + if not cached_result_loaded: + # compute the conv2d on host + tensor_D_ref = np.ones_like(tensor_C) + tensor_ref_A = getTensorRef(tensor_A, self.layout_A, self.conv_kind, problem_size, "a") + tensor_ref_B = getTensorRef(tensor_B, self.layout_B, self.conv_kind, problem_size, "b") + tensor_ref_C = getTensorRef(tensor_C, self.layout_C, self.conv_kind, problem_size, "c") + tensor_ref_D_ref = getTensorRef(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") + + self.host_conv2d( + self.conv_kind, problem_size, + tensor_ref_A, tensor_ref_B, tensor_ref_C, tensor_ref_D_ref, + alpha, beta + ) + + tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") + + if self.enable_cached_results: + cached_test_result.D = cutlass.test.conv.host.TensorHash(tensor_view_D_ref) + cached_results = cutlass.test.conv.host.CachedTestResultListing(conv2d_result_cache_name) + cached_results.append(cached_test_key, cached_test_result) + cached_results.write(conv2d_result_cache_name) + else: + 
return tensor_D_ref + + return cached_test_result.D + + def equal(self, tensor_D, tensor_D_ref, problem_size): + if self.enable_cached_results: + tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d") + tensor_D_hash = cutlass.test.conv.host.TensorHash(tensor_view_D) + + return tensor_D_hash == tensor_D_ref + else: + tensor_view_D = getTensorView(tensor_D, self.layout_D, self.conv_kind, problem_size, "d") + tensor_view_D_ref = getTensorView(tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d") + return cutlass.test.conv.host.equals(tensor_view_D, tensor_view_D_ref) + + def run_cutlass_profiler(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, alpha=1.0, beta=0.0): + + if split_k_mode == cutlass.conv.SplitKMode.Serial: + split_k_mode_ = "serial" + else: + split_k_mode_ = "parallel" + + cutlass_path = os.getenv('CUTLASS_PATH') + assert cutlass_path is not None, "Environment variable 'CUTLASS_PATH' is not defined." + + values = { + "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler", + "kernel_name": self.operation.procedural_name(), + "verification_providers": "device", + "provider": "cutlass", + 'n': str(problem_size.N), + 'h': str(problem_size.H), + 'w': str(problem_size.W), + 'c': str(problem_size.C), + 'k': str(problem_size.K), + 'r': str(problem_size.R), + 's': str(problem_size.S), + 'p': str(problem_size.P), + 'q': str(problem_size.Q), + 'pad_h': str(problem_size.pad_h), + 'pad_w': str(problem_size.pad_w), + 'stride_h': str(problem_size.stride_h), + 'stride_w': str(problem_size.stride_w), + 'dilation_h': str(problem_size.dilation_h), + 'dilation_w': str(problem_size.dilation_w), + 'split_k_slices': str(problem_size.split_k_slices), + 'split_k_mode': split_k_mode_, + 'alpha': str(alpha), + 'beta': str(beta), + 'warmup': str(self.warmup_iterations), + 'profile': str(self.iterations) + } + + cmd_template = \ + "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \ + " --providers=${provider} --n=${n} --h=${h} --w=${w} --c=${c} --k=${k} --r=${r} --s=${s} --p=${p}" \ + " --q=${q} --pad_h=${pad_h} --pad_w=${pad_w} --stride_h=${stride_h} --stride_w=${stride_w}" \ + " --dilation_h=${dilation_h} --dilation_w=${dilation_w} --warmup-iterations=${warmup} --profiling-iterations=${profile}" \ + " --split_k_slices=${split_k_slices} --alpha=${alpha} --beta=${beta} --split_k_mode=${split_k_mode}" + + cmd = SubstituteTemplate(cmd_template, values) + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result) + runtime = float(m.group('runtime')) + + m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result) + bytes = int(m.group('bytes')) + + m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result) + flops = int(m.group('flops')) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, alpha, beta) + assert flops == self.flops(problem_size) + + return runtime + + + + def run(self, problem_size, split_k_mode=cutlass.conv.SplitKMode.Serial, + alpha=1.0, beta=0.0): + + assert get_allocated_size() == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() + + # + # Initialize input and output tensors + # + tensor_A_size = cutlass.conv.implicit_gemm_tensor_a_size(self.conv_kind, problem_size) + tensor_B_size = cutlass.conv.implicit_gemm_tensor_b_size(self.conv_kind, problem_size) + tensor_C_size = cutlass.conv.implicit_gemm_tensor_c_size(self.conv_kind, problem_size) + + np.random.seed(self.seed) + + tensor_A =
self.uniform_init(size=(tensor_A_size,), dtype=self.dtype_A) + tensor_B = self.uniform_init(size=(tensor_B_size,), dtype=self.dtype_B) + tensor_C = self.uniform_init(size=(tensor_C_size,), dtype=self.dtype_C) + tensor_D = np.zeros(shape=(tensor_C_size,), dtype=self.dtype_D) + + + # + # Launch kernel + # + + arguments = Conv2dArguments( + operation=self.operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = LinearCombinationFunctorArguments(alpha, beta), + split_k_slices=problem_size.split_k_slices, + split_k_mode=split_k_mode + ) + + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + implicit_gemm_size = cutlass.conv.implicit_gemm_problem_size(self.operation.conv_kind, arguments.problem_size) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = LinearCombinationFunctorArguments(alpha, beta) + ) + + self.operation.run(arguments) + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + if self.verification: + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + reduction_arguments.sync() + else: + arguments.sync() + + tensor_D_ref = self.host_reference(problem_size, tensor_A, tensor_B, tensor_C, alpha, beta) + + passed = self.equal(tensor_D, tensor_D_ref, problem_size) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, split_k_mode) + + if self.profiling: + sleep(self.sleep_time) + for _ in range(self.warmup_iterations): + self.operation.run(arguments) + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + + self.timer.start() + for _ in range(self.warmup_iterations): + self.operation.run(arguments) + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + self.reduction_operation.run(reduction_arguments) + self.timer.stop_and_wait() + runtime = self.timer.duration(self.iterations) + + # free memory + del arguments + if split_k_mode == cutlass.conv.SplitKMode.Parallel: + del reduction_arguments + + assert get_allocated_size() == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() + if self.profiling: + return runtime + return passed + + + +######################################################################################################## +# TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +# TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +# Additionaly, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +# (conv_blacklist_sizes) +############################################################################################################ + +def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes = [], interleaved=False): # TODO: conv_test_sizes and conv_blacklist_sizes + passed = True + + # + # Testbed object + # + + testbed = Conv2dLauncher(operation, interleaved=interleaved) + + # + # Get conv problem sizes to run conv operator + # + + conv_problems = cutlass.test.conv.TestbedConv2dProblemSizes(64) + + # Vector of conv2d problem sizes to avoid duplicate runs + conv_tested_sizes = [] + + # TODO: include resnet 50 sizes, user sepecified sizes, and rigorous sizes + + # 
Flatten 2D problem_vectors into a 1D problem sizes + problem_sizes = conv_problems.conv2d_default_sizes + + problem_sizes = [conv_problem for conv_problem in problem_sizes] + conv_test_sizes + + # Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slices=1, alpha=1.0, beta=0.0) + for conv_problem in problem_sizes: + + # TODO: skip blacklist problem sizes + if conv_problem in conv_tested_sizes: + continue + + # skip channel dimension % 32 != 0 for interleaved case + if interleaved: + if conv_problem.K % 32 != 0 or conv_problem.C % 32 != 0: + continue + + # + # Procedurally disable certain cases + # + + # CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Unity: + if not ((conv_problem.stride_h == 1) and (conv_problem.stride_w == 1)): + continue + + if not interleaved: + # Fixed channels algorithm requires channel count to match access size + if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.fixed_channels: + if conv_problem.C != operation.A.alignment: + continue + + # Few channels algorithm requires channel count to match access size + if operation.iterator_algorithm == cutlass.conv.IteratorAlgorithm.few_channels: + if conv_problem.C % operation.A.alignment: + continue + + # CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + # Although strided dgrad works for all stride combinations, we are only going + # to run strided dgrad for non-unity strides + + if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided: + if (conv_problem.stride_h == 1) and (conv_problem.stride_w == 1): + continue + + # + # Test + # + + # push back tested problem size to avoid re-running duplicates + conv_tested_sizes.append(conv_problem) + + passed = testbed.run(conv_problem) + + # if not passed: return False + + # TODO: If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the the number of tested problem counts + + if interleaved: + return True + # + # filter the cases for split K + # + + # Small-channels convolution can't run here. + if operation.iterator_algorithm in [cutlass.conv.IteratorAlgorithm.fixed_channels, cutlass.conv.IteratorAlgorithm.few_channels]: + return True + + # CUTLASS DGRAD's *stride* specialization does not support split-k mode + if operation.conv_kind == cutlass.conv.Operator.dgrad and operation.stride_support == StrideSupport.Strided: + conv_problem = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 56, 56, 8), + cutlass.Tensor4DCoord(8, 1, 1, 8), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ) + passed = testbed.run(conv_problem) + + return passed + + # Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + # a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + # which are abolutely neccessary to catch functional bugs. The below code does provide option to sweep + # alpha and beta for local testing, but only runs one value for alpha and beta. 
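    # As the comment above notes, the alpha/beta sweep can be widened for local testing
    # simply by extending the two lists defined below. The values in this sketch are
    # illustrative only and are not part of the checked-in defaults:
    #
    #     problem_alpha = [1.0, 2.0, -1.0]
    #     problem_beta = [0.0, 2.0, -3.0]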
+ + conv2d_split_k_test_size = cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 17, 11, 288), + cutlass.Tensor4DCoord(160, 3, 3, 288), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ) + + split_k_modes = [cutlass.conv.SplitKMode.Parallel, cutlass.conv.SplitKMode.Serial] + + split_k_slices = [1, 2, 3, 4, 201] + problem_alpha = [2.0,] + problem_beta = [2.0,] + + for split_k_mode in split_k_modes: + for split_k_slice in split_k_slices: + for alpha in problem_alpha: + for beta in problem_beta: + passed = testbed.run(conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + alpha, beta) + + return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py new file mode 100644 index 00000000..467d965f --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py @@ -0,0 +1,235 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
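A brief usage sketch of the conv2d testbed defined above may help orient readers. It assumes a Conv2dOperation named fprop_operation has already been constructed elsewhere (tile description, data types, iterator algorithm, and the one-time pycutlass setup such as memory-pool initialization are omitted); the problem-size values are arbitrary examples, not values taken from this patch.

    from pycutlass.test.conv2d_testbed import Conv2dLauncher, test_all_conv2d
    import cutlass

    # `fprop_operation` is an assumed, previously constructed Conv2dOperation;
    # Conv2dLauncher compiles it (plus its reduction kernel) on construction.
    launcher = Conv2dLauncher(fprop_operation, seed=2080, verification=True, profiling=False)

    # One illustrative NHWC fprop problem: N=1, H=W=56, C=64, K=64, 3x3 filter,
    # unit stride/dilation, single split-k slice and group.
    problem = cutlass.conv.Conv2dProblemSize(
        cutlass.Tensor4DCoord(1, 56, 56, 64),
        cutlass.Tensor4DCoord(64, 3, 3, 64),
        cutlass.Tensor4DCoord(1, 1, 1, 1),
        cutlass.MatrixCoord(1, 1),
        cutlass.MatrixCoord(1, 1),
        cutlass.conv.Mode.cross_correlation,
        1, 1
    )

    assert launcher.run(problem, cutlass.conv.SplitKMode.Serial, alpha=1.0, beta=0.0)

    # Or sweep the default problem list plus any user-provided sizes:
    assert test_all_conv2d(fprop_operation, conv_test_sizes=[problem])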
+# +################################################################################################# + +import pycutlass +from pycutlass.test.gemm_testbed import getTensorRef, getTensorView, transpose +from pycutlass import * +import numpy as np +import cutlass +from bfloat16 import bfloat16 + + +class TestbedGrouped: + def __init__(self, operation: GemmOperationGrouped, seed: int = 2080) -> None: + + pycutlass.compiler.add_module([operation]) + + self.seed = seed + + self.operation = operation + + element_size = DataTypeSize[operation.A.element] + + self.dtype_A = self.numpy_type(operation.A.element) + self.dtype_B = self.numpy_type(operation.B.element) + self.dtype_C = self.numpy_type(operation.C.element) + self.dtype_D = self.numpy_type(operation.C.element) + + if element_size == 1: + self.scope_max = 1 + self.scope_min = 0 + elif element_size <= 8: + self.scope_max = 1 + self.scope_min = -1 + elif element_size == 16: + self.scope_max = 4 + self.scope_min = -4 + else: + self.scope_max = 8 + self.scope_min = -8 + + #: compute type + self.compute_type = operation.element_epilogue + + self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + + @staticmethod + def numpy_type(type): + if type == cutlass.float64: + return np.float64 + elif type == cutlass.float32: + return np.float32 + elif type == cutlass.float16: + return np.float16 + elif type == cutlass.bfloat16: + return bfloat16 + elif type == cutlass.int32: + return np.int32 + elif type == cutlass.int8: + return np.int8 + else: + raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) + + def uniform_init(self, size, dtype): + if dtype in [np.float32, np.float16, bfloat16, np.float64]: + return np.ceil( + np.random.uniform( + low=self.scope_min - 0.5, high=self.scope_max - 0.5, + size=size).astype(dtype) + ) + else: + return np.random.uniform( + low=self.scope_min - 1, high=self.scope_max + 1, + size=size).astype(dtype) + + def print_problem_size(self, p): + problem_size = "problem: %d, %d, %d\n" % (p.m(), p.n(), p.k()) + print(problem_size) + + def run(self, problem_count: int, alpha: float = 1.0, beta: float = 0.0) -> bool: + + assert get_allocated_size( + ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() + + # initialize + np.random.seed(self.seed) + + # generate the problem sizes + problem_sizes = [] + tensor_As = [] + tensor_Bs = [] + tensor_Cs = [] + tensor_Ds = [] + tensor_D_refs = [] + + for i in range(problem_count): + if self.dtype_A == np.int8: + if i == 0: + problem_size = cutlass.gemm.GemmCoord(48, 16, 32) + else: + problem_size = cutlass.gemm.GemmCoord( + 16 * np.random.randint(0, 64) + 48, + 16 * np.random.randint(0, 64) + 48, + 16 * np.random.randint(0, 64) + 48 + ) + else: + if i == 0: + problem_size = cutlass.gemm.GemmCoord(48, 16, 8) + else: + problem_size = cutlass.gemm.GemmCoord( + 8 * np.random.randint(0, 64) + 24, + 8 * np.random.randint(0, 64) + 24, + 8 * np.random.randint(0, 64) + 24 + ) + + tensor_As.append( + self.uniform_init( + size=(problem_size.m() * problem_size.k(),), + dtype=self.dtype_A) + ) + tensor_Bs.append( + self.uniform_init( + size=(problem_size.n() * problem_size.k(),), + dtype=self.dtype_B) + ) + tensor_Cs.append( + self.uniform_init( + size=(problem_size.m() * problem_size.n(),), + dtype=self.dtype_C) + ) + + tensor_Ds.append( + np.zeros( + shape=(problem_size.m() * problem_size.n(),), + dtype=self.dtype_D + ) + ) + + tensor_D_refs.append( + np.ones( + shape=(problem_size.m() * problem_size.n(),), + 
dtype=self.dtype_D + ) + ) + + problem_sizes.append(problem_size) + + arguments = GemmGroupedArguments( + operation=self.operation, problem_sizes=problem_sizes, + A=tensor_As, B=tensor_Bs, C=tensor_Cs, D=tensor_Ds, + output_op=LinearCombinationFunctorArguments(alpha, beta) + ) + + self.operation.run(arguments) + + arguments.sync() + + # + # Reference check - TODO: support caching results + # + alpha = self.compute_type(alpha).value() + beta = self.compute_type(beta).value() + init_acc = self.accumulator_type(0).value() + + for idx, problem_size in enumerate(problem_sizes): + if self.operation.switched: + tensor_ref_A = getTensorRef( + tensor_As[idx], problem_size, "a", transpose(self.operation.B.layout)) + tensor_ref_B = getTensorRef( + tensor_Bs[idx], problem_size, "b", transpose(self.operation.A.layout)) + tensor_ref_C = getTensorRef( + tensor_Cs[idx], problem_size, "c", transpose(self.operation.C.layout)) + tensor_ref_D_ref = getTensorRef( + tensor_D_refs[idx], problem_size, "d", transpose(self.operation.C.layout)) + else: + tensor_ref_A = getTensorRef( + tensor_As[idx], problem_size, "a", self.operation.A.layout) + tensor_ref_B = getTensorRef( + tensor_Bs[idx], problem_size, "b", self.operation.B.layout) + tensor_ref_C = getTensorRef( + tensor_Cs[idx], problem_size, "c", self.operation.C.layout) + tensor_ref_D_ref = getTensorRef( + tensor_D_refs[idx], problem_size, "d", self.operation.C.layout) + + tensor_view_D_ref = getTensorView( + tensor_D_refs[idx], problem_size, "d", self.operation.C.layout) + + cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, + tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + + tensor_view_D = getTensorView( + tensor_Ds[idx], problem_size, "d", self.operation.C.layout) + + passed = cutlass.test.gemm.host.equals( + tensor_view_D, tensor_view_D_ref) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size) + + del arguments + + assert get_allocated_size( + ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() + + return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py new file mode 100644 index 00000000..344f20ec --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py @@ -0,0 +1,557 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
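For the grouped-GEMM testbed defined above, a minimal driving sketch looks as follows; grouped_operation is an assumed, previously constructed GemmOperationGrouped, and the group count and epilogue scalars are arbitrary examples.

    from pycutlass.test.gemm_grouped_testbed import TestbedGrouped

    # TestbedGrouped compiles the operation on construction and generates one
    # randomly sized problem per group internally.
    testbed = TestbedGrouped(grouped_operation, seed=2080)

    # Run 24 randomly sized groups with a plain linear-combination epilogue.
    assert testbed.run(problem_count=24, alpha=1.0, beta=0.0)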
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from time import sleep +import pycutlass +from pycutlass import * +import cutlass +from cuda import cudart +from cuda import cuda +from bfloat16 import bfloat16 +from .profiler import GpuTimer +import subprocess + + +def transpose(layout): + if layout == cutlass.RowMajor: + return cutlass.ColumnMajor + elif layout == cutlass.ColumnMajor: + return cutlass.RowMajor + elif layout == cutlass.ColumnMajorInterleaved32: + return cutlass.RowMajorInterleaved32 + elif layout == cutlass.RowMajorInterleaved32: + return cutlass.ColumnMajorInterleaved32 + + +def getTensorRef(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: cutlass.layout): + ptr = tensor.__array_interface__['data'][0] + if operand == "a": + tensor_coord = problem_size.mk() + elif operand == "b": + tensor_coord = problem_size.kn() + elif operand in ["c", "d"]: + tensor_coord = problem_size.mn() + else: + raise ValueError("unknonw operand: " + operand) + + if layout == cutlass.RowMajor: + layout = cutlass.RowMajor.packed(tensor_coord) + layout_tag = "RowMajor" + elif layout == cutlass.ColumnMajor: + layout = cutlass.ColumnMajor.packed(tensor_coord) + layout_tag = "ColumnMajor" + elif layout == cutlass.ColumnMajorInterleaved32: + layout = cutlass.ColumnMajorInterleaved32.packed(tensor_coord) + layout_tag = "ColumnMajorInterleaved32" + elif layout == cutlass.RowMajorInterleaved32: + layout = cutlass.RowMajorInterleaved32.packed(tensor_coord) + layout_tag = "RowMajorInterleaved32" + else: + raise ValueError("unsupported layout") + if tensor.dtype == np.float32: + ref_name = "TensorRefF32" + layout_tag + elif tensor.dtype == np.float64: + ref_name = "TensorRefF64" + layout_tag + elif tensor.dtype == np.float16: + ref_name = "TensorRefF16" + layout_tag + elif tensor.dtype == bfloat16: + ref_name = "TensorRefBF16" + layout_tag + elif tensor.dtype == np.int8: + ref_name = "TensorRefS8" + layout_tag + elif tensor.dtype == np.int32: + ref_name = "TensorRefS32" + layout_tag + else: + raise ValueError("unsupported datatype %s" % + ShortDataTypeNames[tensor.dtype]) + + return getattr(cutlass, ref_name)(ptr, layout) + + +def getTensorView(tensor: np.ndarray, problem_size: cutlass.gemm.GemmCoord, operand: str, layout: str): + tensor_ref = getTensorRef(tensor, problem_size, operand, layout) + + if operand == "a": + tensor_coord = problem_size.mk() + elif operand == "b": + tensor_coord = problem_size.kn() + elif operand in ["c", "d"]: + tensor_coord = problem_size.mn() + else: + raise ValueError("unknonw operand: " + operand) + + if layout == cutlass.RowMajor: + layout_tag = "RowMajor" + elif layout == cutlass.ColumnMajor: + layout_tag = 
"ColumnMajor" + elif layout == cutlass.ColumnMajorInterleaved32: + layout_tag = "ColumnMajorInterleaved32" + elif layout == cutlass.RowMajorInterleaved32: + layout_tag = "RowMajorInterleaved32" + else: + raise ValueError("unsupported layout") + if tensor.dtype == np.float32: + ref_name = "TensorViewF32" + layout_tag + elif tensor.dtype == np.float64: + ref_name = "TensorViewF64" + layout_tag + elif tensor.dtype == np.float16: + ref_name = "TensorViewF16" + layout_tag + elif tensor.dtype == bfloat16: + ref_name = "TensorViewBF16" + layout_tag + elif tensor.dtype == np.int32: + ref_name = "TensorViewS32" + layout_tag + elif tensor.dtype == np.int8: + ref_name = "TensorViewS8" + layout_tag + else: + raise ValueError("unsupported datatype") + + return getattr(cutlass, ref_name)(tensor_ref, tensor_coord) + + +class GemmUniversalLauncher: + def __init__(self, operation: 'GemmOperationUniversal', seed: int = 2080, interleaved=False, + verification=True, profiling=False, warmup_iterations=500, iterations=500, **kwargs) -> None: + # create the reduction kernel + self.reduction_operation: ReductionOperation = ReductionOperation( + shape=cutlass.MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.element_epilogue, + count=operation.C.alignment + ) + + self.math_operation = operation.tile_description.math_instruction.math_operation + + #: verify the output result + self.verification = verification + #: profile the kernel's runtime + self.profiling = profiling + + self.timer = GpuTimer() + + self.warmup_iterations = warmup_iterations + self.iterations = iterations + + if "sleep" in kwargs.keys(): + self.sleep_time = kwargs["sleep"] + else: + self.sleep_time = 0 + + # + # Compile the operator + # + + pycutlass.compiler.add_module([operation, self.reduction_operation]) + + self.operation = operation + + self.dtype_A = GemmUniversalLauncher.numpy_type(operation.A.element) + self.dtype_B = GemmUniversalLauncher.numpy_type(operation.B.element) + self.dtype_C = GemmUniversalLauncher.numpy_type(operation.C.element) + self.dtype_D = GemmUniversalLauncher.numpy_type(operation.C.element) + + accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator] + element_size = DataTypeSize[operation.A.element] + + if element_size == 1: + self.scope_max = 1 + self.scope_min = 0 + elif element_size <= 8: + self.scope_max = 1 + self.scope_min = -1 + elif element_size == 16: + self.scope_max = 4 + self.scope_min = -4 + else: + self.scope_max = 8 + self.scope_min = -8 + + #: seed + self.seed: int = seed + + #: whether the layout is interleaved + self.interleaved = interleaved + + #: compute type + self.compute_type = operation.element_epilogue + self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + + def print_problem_size(self, p, mode, batch_count): + if mode == cutlass.gemm.Mode.Gemm: + mode = "Gemm" + elif mode == cutlass.gemm.Mode.GemmSplitKParallel: + mode = "GemmSplitKParalel" + problem_size = "problem: %d, %d, %d\n batch_count: %d\n mode: %s" % ( + p.m(), p.n(), p.k(), batch_count, mode) + print(problem_size) + + @staticmethod + def numpy_type(type): + if type == cutlass.float64: + return np.float64 + elif type == cutlass.float32: + return np.float32 + elif type == cutlass.float16: + return np.float16 + elif type == cutlass.bfloat16: + return bfloat16 + elif type == cutlass.int32: + return np.int32 + elif type == cutlass.int8: 
+ return np.int8 + else: + raise ValueError("unsupported type: %s" % ShortDataTypeNames[type]) + + def uniform_init(self, size, dtype): + if dtype in [np.float32, np.float16, bfloat16, np.float64]: + return np.ceil( + np.random.uniform( + low=self.scope_min - 0.5, high=self.scope_max - 0.5, + size=size).astype(dtype) + ) + else: + return np.random.uniform( + low=self.scope_min - 1, high=self.scope_max + 1, + size=size).astype(dtype) + + def reorder_tensor_B(self, tensor_B, problem_size): + reordered_tensor_B = np.empty_like(tensor_B) + tensor_ref_B = getTensorRef( + tensor_B, problem_size, "b", self.operation.B.layout) + reordered_tensor_ref_B = getTensorRef( + reordered_tensor_B, problem_size, "b", self.operation.B.layout) + cutlass.gemm.host.reorder_column( + tensor_ref_B, reordered_tensor_ref_B, problem_size) + return reordered_tensor_B + + def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): + # TODO + tensor_D_ref = np.ones_like(tensor_C) + alpha = self.numpy_type(self.compute_type)(alpha) + beta = self.numpy_type(self.compute_type)(beta) + init_acc = 0 + + alpha = self.compute_type(alpha).value() + beta = self.compute_type(beta).value() + init_acc = self.accumulator_type(init_acc).value() + + if self.operation.switched: + tensor_ref_A = getTensorRef( + tensor_A, problem_size, "a", transpose(self.operation.B.layout)) + tensor_ref_B = getTensorRef( + tensor_B, problem_size, "b", transpose(self.operation.A.layout)) + tensor_ref_C = getTensorRef( + tensor_C, problem_size, "c", transpose(self.operation.C.layout)) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, problem_size, "d", transpose(self.operation.C.layout)) + else: + tensor_ref_A = getTensorRef( + tensor_A, problem_size, "a", self.operation.A.layout) + tensor_ref_B = getTensorRef( + tensor_B, problem_size, "b", self.operation.B.layout) + tensor_ref_C = getTensorRef( + tensor_C, problem_size, "c", self.operation.C.layout) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, problem_size, "d", self.operation.C.layout) + + if self.math_operation in [MathOperation.multiply_add_saturate]: + cutlass.test.gemm.host.gemm_saturate( + problem_size, alpha, tensor_ref_A, tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + else: + cutlass.test.gemm.host.gemm(problem_size, alpha, tensor_ref_A, + tensor_ref_B, beta, tensor_ref_C, tensor_ref_D_ref, init_acc) + + return tensor_D_ref + + def equal(self, tensor_D, tensor_D_ref, problem_size): + + tensor_view_D = getTensorView( + tensor_D, problem_size, "d", self.operation.C.layout) + tensor_view_D_ref = getTensorView( + tensor_D_ref, problem_size, "d", self.operation.C.layout) + + return cutlass.test.gemm.host.equals(tensor_view_D, tensor_view_D_ref) + + def bytes(self, problem_size, batch_count=1, alpha=1.0, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + bytes = \ + (DataTypeSize[self.operation.A.element] * m // 8) * k + \ + (DataTypeSize[self.operation.B.element] * n // 8) * k + \ + (DataTypeSize[self.operation.C.element] * m // 8) * n + + if beta != 0: + bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n + + bytes *= batch_count + + return bytes + + def flops(self, problem_size, batch_count=1): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + flops_ = (m * n * k + m * n) * 2 * batch_count + + # TODO: complex + return flops_ + + def run_cutlass_profiler(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): + + cutlass_path = os.getenv('CUTLASS_PATH') + assert cutlass_path is 
not None, "Environment variable 'CUTLASS_PATH' is not defined." + + values = { + "profiler_path": cutlass_path + "/build/tools/profiler/cutlass_profiler", + "kernel_name": self.operation.procedural_name(), + "verification_providers": "device", + "provider": "cutlass", + "m": str(problem_size.m()), + "n": str(problem_size.n()), + "k": str(problem_size.k()), + 'split_k_slices': str(batch_count), + 'alpha': str(alpha), + 'beta': str(beta), + 'warmup': str(self.warmup_iterations), + 'profile': str(self.iterations) + } + + cmd_template = \ + "${profiler_path} --kernels=${kernel_name} --verification-providers=${verification_providers}" \ + " --providers=${provider} --m=${m} --n=${n} --k=${k}" + + cmd = SubstituteTemplate(cmd_template, values) + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P<runtime>\d+\.\d+)", result) + runtime = float(m.group('runtime')) + + m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result) + bytes = int(m.group('bytes')) + + m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result) + flops = int(m.group('flops')) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, alpha, beta) + assert flops == self.flops(problem_size) + + return runtime + + def run(self, mode, problem_size, batch_count=1, alpha=1.0, beta=0.0): + + assert get_allocated_size( + ) == 0, "%d byte of pool memory is not released in previous run" % get_allocated_size() + + np.random.seed(self.seed) + + tensor_A = self.uniform_init( + size=(problem_size.m() * problem_size.k(),), dtype=self.dtype_A) + tensor_B = self.uniform_init( + size=(problem_size.n() * problem_size.k(),), dtype=self.dtype_B) + tensor_C = self.uniform_init( + size=(problem_size.m() * problem_size.n(),), dtype=self.dtype_C) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),), dtype=self.dtype_D) + + # + # Launch kernel + # + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=LinearCombinationFunctorArguments(alpha, beta), + gemm_mode=mode, split_k_slices=batch_count + ) + + if mode == cutlass.gemm.Mode.GemmSplitKParallel: + reduction_arguments = ReductionArguments( + self.reduction_operation, problem_size=[ + problem_size.m(), problem_size.n()], + partitions=batch_count, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op=LinearCombinationFunctorArguments(alpha, beta) + ) + + self.operation.run(arguments) + + if mode == cutlass.gemm.Mode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + + if self.verification: + if mode == cutlass.gemm.Mode.GemmSplitKParallel: + reduction_arguments.sync() + else: + arguments.sync() + tensor_D_ref = self.host_reference( + problem_size, tensor_A, tensor_B, tensor_C, alpha, beta) + passed = self.equal(tensor_D, tensor_D_ref, problem_size) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, mode, batch_count) + + if self.profiling: + sleep(self.sleep_time) + for _ in range(self.warmup_iterations): + self.operation.run(arguments) + if mode == cutlass.gemm.Mode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + + self.timer.start() + for _ in range(self.iterations): + self.operation.run(arguments) + if mode == cutlass.gemm.Mode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + self.timer.stop_and_wait() + + runtime = self.timer.duration(self.iterations) + + # free memory and clear buffers + del arguments + if mode ==
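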
cutlass.gemm.Mode.GemmSplitKParallel: + del reduction_arguments + + assert get_allocated_size( + ) == 0, "%d byte of pool memory is not released after current run" % get_allocated_size() + + if self.profiling: + return runtime + return passed + + +def test_all_gemm(operation: 'GemmOperationUniversal', testcase="universal"): + + passed = True + + minimum_operand_element_size = min( + DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) + opcode_class = operation.tile_description.math_instruction.opcode_class + + if opcode_class == cutlass.OpClass.Simt: + alignment = 1 + else: + alignment = 128 // minimum_operand_element_size + + # int8_t gemm alignment constrainst + if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 and operation.A.layout == cutlass.ColumnMajor: + alignment_m = 4 + else: + alignment_m = alignment + + if opcode_class == cutlass.OpClass.Simt and operation.B.element == cutlass.int8 and operation.A.layout == cutlass.RowMajor: + alignment_n = 4 + else: + alignment_n = alignment + + if opcode_class == cutlass.OpClass.Simt and operation.A.element == cutlass.int8 \ + and operation.B.element == cutlass.int8 \ + and (operation.A.layout == cutlass.RowMajor or operation.B.layout == cutlass.ColumnMajor): + + alignment_k = 4 + else: + alignment_k = alignment + + threadblock_k = operation.tile_description.threadblock_shape[2] + + if testcase == "interleaved": + if operation.A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: + interleavedk = 32 + else: + raise ValueError("unknonw layout") + + if testcase == "interleaved": + modes = [cutlass.gemm.Mode.Gemm, ] + problem_size_m = [interleavedk, 512+interleavedk] + problem_size_n = [interleavedk, 512+interleavedk] + problem_size_k = [interleavedk, threadblock_k * + operation.tile_description.stages + interleavedk] + problem_alpha = [1.0] + problem_beta = [0.0] + batch_counts = [1, ] + elif testcase == "multistage": + modes = [cutlass.gemm.Mode.Gemm, ] + problem_size_m = [16, 528] + problem_size_n = [16, 528] + problem_size_k = [threadblock_k, threadblock_k * operation.tile_description.stages + + operation.tile_description.math_instruction.instruction_shape[2]] + problem_alpha = [1.0] + problem_beta = [0.0] + batch_counts = [1, ] + else: # universal + modes = [cutlass.gemm.Mode.Gemm, cutlass.gemm.Mode.GemmSplitKParallel] + problem_size_m = [alignment_m, 512 - 3 * alignment_m] + problem_size_n = [alignment_n, 512 - 2 * alignment_n] + problem_size_k = [ + alignment_k, + threadblock_k * operation.tile_description.stages - alignment_k, + threadblock_k * operation.tile_description.stages * 3 - alignment_k] + batch_counts = [1, 2, 3, 5, 7] + problem_alpha = [1.0] + problem_beta = [2.0] + + testbed = GemmUniversalLauncher( + operation, interleaved=(testcase == "interleaved")) + + for mode in modes: + for m in problem_size_m: + for n in problem_size_n: + for k in problem_size_k: + for batch_count in batch_counts: + for alpha in problem_alpha: + for beta in problem_beta: + # skip very small K problems + if testcase == "universal": + if (k // batch_count < 2 * threadblock_k): + continue + + problem_size = cutlass.gemm.GemmCoord(m, n, k) + + passed = testbed.run( + mode, problem_size, batch_count, alpha, beta) + + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + "CUDA Error %s" % str(err)) + + if not passed: + return False + + return passed diff --git a/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py 
b/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py new file mode 100644 index 00000000..7738dd72 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py @@ -0,0 +1,70 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
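The GEMM testbed above can be driven the same way as the conv2d one; the sketch below assumes a GemmOperationUniversal named gemm_operation built elsewhere, with the one-time pycutlass setup omitted and the problem shape chosen arbitrarily.

    from pycutlass.test.gemm_testbed import GemmUniversalLauncher, test_all_gemm
    import cutlass

    # `gemm_operation` is an assumed, previously constructed GemmOperationUniversal.
    launcher = GemmUniversalLauncher(gemm_operation, verification=True, profiling=False)

    # One illustrative problem in plain GEMM mode, then the standard unit-test sweep.
    problem = cutlass.gemm.GemmCoord(512, 256, 128)
    assert launcher.run(cutlass.gemm.Mode.Gemm, problem, batch_count=1, alpha=1.0, beta=0.0)
    assert test_all_gemm(gemm_operation, testcase="universal")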
+# +################################################################################################# + +from cuda import cuda +from cuda import cudart + + +class GpuTimer: + def __init__(self) -> None: + self.events = [ + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1] + ] + + def start(self, stream=cuda.CUstream(0)): + err, = cuda.cuEventRecord(self.events[0], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + def stop(self, stream=cuda.CUstream(0)): + err, = cuda.cuEventRecord(self.events[1], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + pass + + def stop_and_wait(self, stream=cuda.CUstream(0)): + self.stop(stream) + if stream: + err, = cuda.cuStreamSynchronize(stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + else: + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + def duration(self, iterations=1): + err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1]) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + return duration / float(iterations) diff --git a/tools/library/scripts/pycutlass/src/pycutlass/type.py b/tools/library/scripts/pycutlass/src/pycutlass/type.py new file mode 100644 index 00000000..771a8c3e --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/type.py @@ -0,0 +1,39 @@ +################################################################################ +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
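The GpuTimer defined above wraps a pair of CUDA events; a minimal timing sketch is shown below. Since cuEventElapsedTime reports milliseconds, duration() returns the average number of milliseconds per iteration over the timed region.

    from pycutlass.test.profiler import GpuTimer

    timer = GpuTimer()
    timer.start()
    # ... enqueue the kernels to be timed on the default stream ...
    timer.stop_and_wait()
    print("elapsed ms:", timer.duration(iterations=1))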
+# +################################################################################ + +from typing import Union +from typeguard import typechecked + + +GemmOperation = 'Union[GemmOperationUniversal, GemmOperationGrouped]' + +Tensor = 'Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]' diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py new file mode 100644 index 00000000..70e44b18 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py @@ -0,0 +1 @@ +from pycutlass.utils.reference_model import * diff --git a/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py b/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py new file mode 100644 index 00000000..809fcf99 --- /dev/null +++ b/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py @@ -0,0 +1,234 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import numpy as np +import cutlass +from pycutlass.library import TensorDescription +from typing import Union +try: + import torch + torch_available = True +except ImportError: + torch_available = False + +class ReferenceModule: + def __init__(self, A: TensorDescription, B: TensorDescription, C: TensorDescription) -> None: + self.layout_A = A.layout + self.layout_B = B.layout + self.layout_C = C.layout + + def run(self, A: np.ndarray, B: np.ndarray, C: np.ndarray, problem_size: cutlass.gemm.GemmCoord, alpha: float=1.0, beta: float=0.0): + """ + Compute the reference result on CPU + Args: + A: dense operator with shape (M, K) in row-major and (K, M) in column-major + B: dense operator with shape (K, N) in row-major and (N, K) in column-major + C: dense operator with shape (M, N) in row-major and (N, M) in column-major + """ + M, N, K = problem_size.m(), problem_size.n(), problem_size.k() + if isinstance(A, np.ndarray): + if self.layout_A == cutlass.RowMajor: + A_row = np.reshape(A, newshape=(M, K)) + else: + A_col = np.reshape(A, newshape=(K, M)) + A_row = np.transpose(A_col, axes=(1, 0)) + + if self.layout_B == cutlass.RowMajor: + B_row = np.reshape(B, newshape=(K, N)) + else: + B_col = np.reshape(B, newshape=(N, K)) + B_row = np.transpose(B_col, axes=(1, 0)) + + if self.layout_C == cutlass.RowMajor: + C_row = np.reshape(C, newshape=(M, N)) + else: + C_col = np.reshape(C, newshape=(N, M)) + C_row = np.transpose(C_col, axes=(1, 0)) + + out_row = np.matmul(A_row, B_row) * alpha + C_row * beta + + if self.layout_C == cutlass.ColumnMajor: + out = np.transpose(out_row, axes=(1, 0)) + else: + out = out_row + + return out.ravel() + + elif isinstance(A, torch.Tensor): + if self.layout_A == cutlass.RowMajor: + A_row = A.view((M, K)) + else: + A_col = A.view((K, M)) + A_row = torch.permute(A_col, (1, 0)) + + if self.layout_B == cutlass.RowMajor: + B_row = B.view((K, N)) + else: + B_col = B.view((N, K)) + B_row = torch.permute(B_col, (1, 0)) + + if self.layout_C == cutlass.RowMajor: + C_row = C.view((M, N)) + else: + C_col = C.view((N, M)) + C_row = torch.permute(C_col, (1, 0)) + + out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta + + if self.layout_C == cutlass.ColumnMajor: + out = torch.permute(out_row, (1, 0)) + else: + out = out_row + + return torch.flatten(out) + + + +##################################################################################################### +# Conv2d +##################################################################################################### + +if torch_available: + class Conv2dReferenceModule: + def __init__(self, A: TensorDescription, B: TensorDescription, C: TensorDescription, kind: cutlass.conv.Operator.fprop) -> None: + self.layout_A = A.layout + self.layout_B = B.layout + self.layout_C = C.layout + self.kind = kind + + def run(self, + A: Union[np.ndarray, torch.Tensor], + B: Union[np.ndarray, torch.Tensor], + C: Union[np.ndarray, torch.Tensor], problem_size, alpha=1.0, beta=0.0) -> np.ndarray: + """ + Compute the reference result on CPU + """ + n = problem_size.N + h = problem_size.H + w = problem_size.W + c = problem_size.C + + k = problem_size.K + r = problem_size.R + s = problem_size.S + + p = problem_size.P + q = problem_size.Q + + stride_h = problem_size.stride_h + stride_w = problem_size.stride_w + + pad_h = problem_size.pad_h + pad_w = problem_size.pad_w + + dilation_h = problem_size.dilation_h + dilation_w = problem_size.dilation_w 
+ + groups = problem_size.groups + + if isinstance(A, np.ndarray): + # the pytorch activation layout is NCHW + # weight layout is Cout Cin Kh Kw (also NCHW) + if self.layout_A == cutlass.TensorNHWC: + A_nhwc = np.reshape(A, newshape=(n, h, w, c)) + A_torch_nhwc = torch.from_numpy(A_nhwc).to("cuda") + A_torch_nchw = torch.permute(A_torch_nhwc, (0, 3, 1, 2)) + + if self.layout_B == cutlass.TensorNHWC: + B_nhwc = np.reshape(B, newshape=(k, r, s, c)) + B_torch_nhwc = torch.from_numpy(B_nhwc).to("cuda") + B_torch_nchw = torch.permute(B_torch_nhwc, (0, 3, 1, 2)) + + if self.layout_C == cutlass.TensorNHWC: + C_nhwc = np.reshape(C, newshape=(n, p, q, k)) + C_torch_nhwc = torch.from_numpy(C_nhwc).to("cuda") + C_torch_nchw = torch.permute(C_torch_nhwc, (0, 3, 1, 2)) + + elif isinstance(A, torch.Tensor): + if self.kind == cutlass.conv.Operator.wgrad: + if self.layout_A == cutlass.TensorNHWC: + A_nhwc = A.view((n, p, q, k)) + A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) + + if self.layout_B == cutlass.TensorNHWC: + B_nhwc = B.view((n, h, w, c)) + B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) + + if self.layout_C == cutlass.TensorNHWC: + C_nhwc = C.view((k, r, s, c)) + C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) + elif self.kind == cutlass.conv.Operator.dgrad: + if self.layout_A == cutlass.TensorNHWC: + A_nhwc = A.view((n, p, q, k)) + A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) + + if self.layout_B == cutlass.TensorNHWC: + B_nhwc = B.view((k, r, s, c)) + B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) + + if self.layout_C == cutlass.TensorNHWC: + C_nhwc = C.view((n, h, w, c)) + C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) + else: + if self.layout_A == cutlass.TensorNHWC: + A_nhwc = A.view((n, h, w, c)) + A_torch_nchw = torch.permute(A_nhwc, (0, 3, 1, 2)) + + if self.layout_B == cutlass.TensorNHWC: + B_nhwc = B.view((k, r, s, c)) + B_torch_nchw = torch.permute(B_nhwc, (0, 3, 1, 2)) + + if self.layout_C == cutlass.TensorNHWC: + C_nhwc = C.view((n, p, q, k)) + C_torch_nchw = torch.permute(C_nhwc, (0, 3, 1, 2)) + + if self.kind == cutlass.conv.Operator.fprop: + D_torch_nchw = alpha * torch.nn.functional.conv2d( + A_torch_nchw, B_torch_nchw, stride=(stride_h, stride_w), + padding=(pad_h, pad_w), dilation=(dilation_h, dilation_w), groups=groups) + beta * C_torch_nchw + elif self.kind == cutlass.conv.Operator.dgrad: + D_torch_nchw = alpha * torch.nn.grad.conv2d_input( + (n, c, h, w), B_torch_nchw, A_torch_nchw, padding=(pad_h, pad_w), stride=(stride_h, stride_w) + ).to(torch.float32) + beta * C_torch_nchw + elif self.kind == cutlass.conv.Operator.wgrad: + D_torch_nchw = alpha * torch.nn.grad.conv2d_weight( + B_torch_nchw, (k, c, r, s), A_torch_nchw, padding=(pad_h, pad_w), stride=(stride_h, stride_w) + ).to(torch.float32) + beta * C_torch_nchw + + + if self.layout_C == cutlass.TensorNHWC: + if isinstance(A, np.ndarray): + D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1)).detach().cpu().numpy() + elif isinstance(A, torch.Tensor): + D_torch_out = torch.permute(D_torch_nchw, (0, 2, 3, 1)) + + return D_torch_out.flatten() diff --git a/tools/library/scripts/pycutlass/test/conv/__init__.py b/tools/library/scripts/pycutlass/test/conv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt b/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt new file mode 100644 index 00000000..91cbe531 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/cached_results_SM80.txt @@ -0,0 
+1,274 @@ +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 2104699940 3506659864 557648934 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1539314507 3971227455 1976927351 1642148785 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 276489656 653235219 3147305346 880610205 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 272457724 2178229139 2786201726 4170295839 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 242235041 2149454506 784935854 682531065 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 3478189705 1667216236 1437761176 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 379326961 1780379994 3740415776 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 924848818 3533854396 2683779476 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 359232443 2147867990 1653277018 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 3784314846 2644315999 4224154526 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3787448414 3562991793 535073859 2563373454 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 426169840 2464808416 864648234 461884698 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2564934525 3910792915 3577331017 827498183 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 28479234 867695528 1947311971 83328334 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4192922822 4244595864 2296602326 2349214706 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 274678245 3464152269 1682550229 3446204619 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3993280136 828543035 1319748516 956044554 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 832003025 3799813757 4030292245 457791957 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1444316594 4129865888 93616503 412257611 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 36703874 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 1842147148 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1612565294 109894479 1782187316 3370789453 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 841569299 1010785577 1158956167 3261208135 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1893352157 48149942 
3544807462 446577726 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 3585320147 2150950452 1625817025 3964129474 +conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 2624928614 3423533117 3186342135 +conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 2732296888 1838622641 4203745561 +conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3456572634 893492926 1966259884 +conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 4014726279 4027869577 1510990157 +conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 4140605332 3580988556 3425909428 +conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2106553169 835800311 3417471222 +conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 860217059 166776702 1109666471 +conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 855244826 2670006594 3857976152 +conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 3079461262 3579256638 2926210806 +conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2952423142 2045838875 3445165841 +conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 2133381336 2601441527 2035094220 +conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 1700915522 2515933441 406719240 +conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 156533442 1012781676 688128904 +conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 3117803557 1370701307 1462167731 +conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 973422497 1926250028 3440543762 +conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 2892862516 3649300762 1521470286 +conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 3181416651 1733426984 872275640 +conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 1639170045 388151578 4186957447 +conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1433744686 860506550 3475157408 +conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1747719409 877465841 2345541783 +conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 2307248012 337386755 3363072703 +conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 722034901 2562804622 2508759317 +conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 2196645331 
3235235362 1518334120 +conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 72559978 778918419 1260968000 +conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 2634885882 451986822 3792829599 +conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 2426759809 2622222681 371723930 +conv2d dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 3612826298 2531545294 476754549 +conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 2391975923 197605094 3409942185 +conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3071904063 408984565 2378809888 +conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 3067676760 1540919649 2008865071 +conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 1085505037 2778215386 230227569 +conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2731079464 3570839563 3483629877 +conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 408419601 3415600242 2106927195 +conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 3606099389 4034802752 3200055633 +conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 3910244699 1319285699 2229775542 +conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 2780071616 2703730845 3090625734 +conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 4278696824 360883914 3802692600 +conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 653419877 359675571 283806385 +conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 1075980921 3101013494 2025203940 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1393431534 1148212814 1350914659 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 4283492776 419570292 1210341563 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4178596783 3828059710 2735749436 2671012171 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 924522595 563724475 3750778972 4152580670 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1021044158 1686067905 3765040166 4102272733 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 2674994719 635224486 2759329777 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 4201252830 2920298728 304256151 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 70289262 646435722 4137562540 
+conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 1288095320 2132879813 656196754 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 2202157489 2326567490 2475188414 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2476454437 1857118302 4164386062 239840568 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2767650699 3514840131 590439733 3879821123 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3896287283 3112762669 2515107934 2106635937 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1903067870 1021832870 3003938078 2751931686 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3489785028 2466126497 1374078692 2737628040 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2051350923 263676708 3639860119 1370886256 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 719099834 1474713672 204857540 2768940347 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3441724486 3162593831 421721594 3097845598 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2034354027 1249407570 2567025479 1441082595 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 2369653089 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 1218705038 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 172579142 319546523 718795680 1453661415 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2823351660 1326352711 1110204809 1155441703 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3238446487 2572503545 686287700 1559476701 +conv2d fprop_1x8x8x1_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1883874274 1180207512 3934800419 +conv2d fprop_1x16x16x1_8x8_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 4230587034 4117433929 2540623821 +conv2d fprop_1x16x16x1_12x12_16x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 3802993432 1563447158 515257167 +conv2d fprop_1x224x224x1_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 2583340103 3928463259 1564251818 +conv2d fprop_1x224x224x1_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 2966178620 3457283045 1726663817 +conv2d fprop_1x224x224x1_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 1794561978 3101289788 3492498648 +conv2d fprop_1x224x224x1_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 7656882 1794561978 498358130 4111289929 +conv2d fprop_1x8x8x2_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 
hnhwc_hnhwc_hnhwc_f_f 2693144988 3876248534 3038023830 1910263513 +conv2d fprop_1x16x16x2_8x8_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 3355193355 319259163 535683577 +conv2d fprop_1x16x16x2_12x12_16x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 1548147432 3385829172 2741952709 +conv2d fprop_1x224x224x2_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 2686562907 3948710179 3669872932 +conv2d fprop_1x224x224x2_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 576815792 2317227037 1211532666 +conv2d fprop_1x224x224x2_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 27596985 555460201 895685163 +conv2d fprop_1x224x224x2_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3479872296 27596985 1465341652 2228916523 +conv2d fprop_1x8x8x4_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 24290453 137535877 1436667267 1395660627 +conv2d fprop_1x224x224x4_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 2226159049 4051661898 209529384 +conv2d fprop_1x224x224x4_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 3541851870 2271016226 2671623385 +conv2d fprop_1x224x224x4_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 982184919 2007343215 3362992769 +conv2d fprop_1x224x224x4_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2495921302 982184919 20610297 1086800078 +conv2d fprop_1x8x8x8_4x4_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 3117444553 1497663382 3561001103 +conv2d fprop_1x224x224x8_220x220_32x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 1414143072 827338392 2827855918 +conv2d fprop_1x224x224x8_110x110_64x7x7_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 3886996022 26545788 3407771964 +conv2d fprop_1x224x224x8_222x222_64x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 380272816 2374613655 3601677176 +conv2d fprop_1x224x224x8_111x111_64x5x5_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3188907679 380272816 778374730 2110111988 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1736512560 49406874 846358010 3314905564 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1848484956 1432417472 1903569827 3750799351 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4236427320 3696009469 69852620 201921851 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 109006944 450017448 1793784844 903209915 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 813367872 2397796503 1928191746 3210229460 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1307184141 46021356 1674017987 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1212511562 3331767121 2446286369 +conv2d 
fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 2013675943 1681111033 1469213228 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 500298386 3218034344 4159283207 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 1123534155 145385311 4273847179 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3862659311 349459322 1503631520 1404971956 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1623686755 961217371 552550209 3980749384 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3554927580 1131648083 4149599295 3119557776 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1767639287 3350675774 128324027 1059816532 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3986143536 17411088 40173029 1694092310 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1157793540 3513299281 48848814 1435528367 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 988962069 4292634763 388976034 2674929544 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4202383208 3529769234 1046186503 3368902675 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 856448884 3057259762 2063087558 1995545427 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 400986166 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 1082696406 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2702905851 1992889713 731289041 608504198 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2742293143 4197915274 606840 3671124731 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 149434841 2288560511 2994968424 2881838300 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 2226824643 327135318 3718671210 2121176659 +conv2d fprop_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3254575292 1119957081 672831271 +conv2d fprop_1x4x4x14_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3115523958 3622905002 4020453928 3853387318 +conv2d fprop_1x23x56x98_10x22_128x3x3_pad_h4w5_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1702870033 1876930844 1190400523 3937287850 +conv2d fprop_1x4x4x28_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 2587856937 2021107274 2789519899 +conv2d fprop_1x23x56x100_10x22_128x3x3_pad_h4w5_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2368669977 1353376771 744357395 786349633 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 991402150 
1393431534 2496492611 3901723984 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4208297221 4283492776 3148637036 258220505 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4178596783 3828059710 281106520 1103939403 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 924522595 563724475 1938163814 2197809394 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1021044158 1686067905 350851834 3999808950 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 2674994719 1034822169 1611033520 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 4201252830 1597212204 2181492560 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3335547 70289262 3001492060 1379239000 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1317457392 1288095320 4211138051 2804617605 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1317457392 2202157489 1043108884 2923122465 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2476454437 1857118302 3877008798 1206012078 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2767650699 3514840131 2946529611 3907056932 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3896287283 3112762669 1581171257 3959460786 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1903067870 1021832870 1926804094 1756790353 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3489785028 2466126497 1712378956 434322965 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2051350923 263676708 355203300 821870356 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 719099834 1474713672 2886387159 4086314983 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3441724486 3162593831 1422796372 2049419539 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2034354027 1249407570 1196036582 2684312264 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 941893937 3608468045 2198911423 1060050551 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 941893937 3608468045 2198911423 3361618746 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 172579142 319546523 2332616929 543467298 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2823351660 1326352711 3839068434 65031397 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3238446487 2572503545 3604065639 2111204111 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 
hnhwc_hnhwc_fnhwc_f_f 2149247508 1775375365 2663631601 1249487679 +conv2d fprop_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 403997062 1679063623 4062928786 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 1623218578 436154205 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1479940693 3253144559 3883419107 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 1871463331 2425320272 74566211 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 664160900 3610888033 22347127 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1924855848 1382111427 2541177413 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 1764715518 3070473696 2392864704 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 666906244 3401957738 2050602745 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4195072693 1575210381 781892324 2848949054 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 2316839359 1539389419 4293781748 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 2693098375 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 1969608051 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 554790212 2885143346 780489333 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 835105643 3337423971 3866137775 +conv2d dgrad_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2956180805 1092015789 3160693693 1526395881 +conv2d dgrad_1x56x56x12_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 1941683430 2236679600 3168985259 +conv2d dgrad_1x55x55x12_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 1941683430 3784328837 471971363 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 1266976707 942688231 3457364823 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 2005082293 2235558527 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3380032042 1370040310 1348846927 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 1423304149 2107662762 1234913781 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 1709026638 2421185623 3308071321 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2519327328 2541413264 3185574975 +conv2d 
wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2070174510 1364436192 3531942595 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 2056902987 3079166829 2329433528 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3227877956 645422556 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3817218800 985231315 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 1398036015 3630062764 2492522537 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 643733019 3649549642 2637869234 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2332160299 302086821 3303132343 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 2458714707 2919710256 2311575036 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2260022344 500095455 2760458995 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 3635363851 2402907878 4131497953 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 2536338700 2459524764 2504484273 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 2667385029 2714805835 3487838445 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 1547169349 3198573835 302049294 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 1317923157 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 3186679687 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 4220759192 2236533218 3731336532 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 1591352238 1756650151 1262787222 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 892422645 1334708242 1372556938 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 150035460 2897171548 3701081496 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 4106152802 2634710231 744755886 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 2709881923 2407415563 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 3723472741 3733128758 3129111191 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 2042513140 253288229 404121198 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 
1116254439 525487530 3284739065 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1743485155 91136873 2508716910 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 386662952 1127709182 4026285141 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 3954249564 2591894666 2655687700 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1300426008 1263618595 1313664339 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1300426008 1756414462 2995557277 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 447261065 121940906 1497499264 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 2966693627 1423016429 341928547 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1759979610 2761559427 68093525 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 2980501720 1650970502 3258883197 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 3502822733 3985958544 2568949300 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 3289288595 385631111 328914986 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 3391080565 1513955316 1521294163 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4195072693 1669352457 2608107448 4284090805 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 1126870455 106232038 3054809396 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 1723074453 1186911503 4239438967 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 1723074453 1186911503 2113601884 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 2413490039 36034283 1112346965 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 1601750164 14375779 2894970748 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 1300976652 4259930640 305685205 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 1747587481 4137156526 1174257270 +conv2d wgrad_1x4x4x12_1x1_8x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2956180805 1086820986 1644914756 2013471312 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 447674669 724481645 1457430910 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3401425854 3897766524 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 
fnhwc_fnhwc_fnhwc_f_f 1972540475 3749787834 3350064812 1136116240 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 820341033 770836461 2451581199 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 2581696511 1088458082 1521190911 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2885454895 935600441 2615245898 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3831334389 3506139121 814982501 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 737968461 1291834254 2665225480 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 1809195644 1765637461 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 3379808294 483095299 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 4194153035 2863868771 1639389008 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 157618421 1779474147 814087242 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 2300180628 423968553 3890279569 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 1848932917 522753581 1926508271 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 3663040534 4014266327 1288646188 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 1585195072 1487505772 3253374264 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 451194147 3578359696 3659768981 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2780826684 2883769406 148530958 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 3849874822 102765469 1305171059 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1516344656 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1586331550 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 2274021368 1188866747 3178890497 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 1226457131 4187777346 1400559240 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 1585959358 3731079159 1498901684 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 2758666204 3287095476 4291916486 +conv2d wgrad_1x8x8x1_8x8_1x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 4278264698 2331753571 2554564568 +conv2d 
dgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 927718585 3117803557 1370701307 1462167731 +conv2d dgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 973422497 1926250028 3440543762 +conv2d dgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 2892862516 3649300762 1521470286 +conv2d dgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 2075083065 3181416651 1733426984 872275640 +conv2d dgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4005590448 1639170045 388151578 4186957447 +conv2d dgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 181075276 1433744686 860506550 3475157408 +conv2d dgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1747719409 877465841 2345541783 +conv2d dgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 856324887 2307248012 337386755 3363072703 +conv2d dgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1906605830 722034901 2562804622 2508759317 +conv2d dgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 805717279 2196645331 3235235362 1518334120 +conv2d dgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3168796339 72559978 778918419 1260968000 +conv2d dgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 261954979 2634885882 451986822 3792829599 +conv2d dgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_fnhwc_f_f 3747142491 2426759809 2622222681 371723930 diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py new file mode 100644 index 00000000..fd3309bc --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -0,0 +1,187 @@ +# test/unit/conv/device/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +from pycutlass.conv2d_operation import * +from pycutlass import * +from pycutlass.test import * +import unittest + + +class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + 
conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = 
TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..bb8eff46 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,162 @@ +# test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, 
opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=4, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3_64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage4_64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=4, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + 
pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py new file mode 100644 index 00000000..50cb2598 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -0,0 +1,89 @@ +# test/unit/conv/device/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +import pycutlass +from pycutlass.conv2d_operation import * +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[4, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[2, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py 
b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..ea3ba2b0 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,86 @@ +# test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Unity, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..7a8e8ba3 --- /dev/null +++ 
b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,154 @@ +# test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass.test import * +import unittest + +def conv2d_few_channel_problemsizes(channels): + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 8, 8, channels), + cutlass.Tensor4DCoord(16, 3, 3, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 16, 16, channels), + cutlass.Tensor4DCoord(16, 3, 3, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 16, 16, channels), + cutlass.Tensor4DCoord(16, 7, 7, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(32, 7, 7, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 7, 7, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 5, 5, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 5, 5, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + return problem_sizes + +class Conv2dFpropFewChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=2) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=2) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + 
swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation, conv2d_few_channel_problemsizes(2))) + + def test_SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_1(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=1) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=1) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=2, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.few_channels, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation, conv2d_few_channel_problemsizes(1))) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..43c38c81 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,175 @@ +# test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass.test import * +import unittest + +def conv2d_fixed_channel_problemsizes(channels): + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 8, 8, channels), + cutlass.Tensor4DCoord(16, 3, 3, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(32, 7, 7, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 7, 7, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 5, 5, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 224, 224, channels), + cutlass.Tensor4DCoord(64, 5, 5, channels), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 
1, 1 + ), + ] + + return problem_sizes + +class Conv2dFpropFixedChannelsF16NHWCF16NHWCF16HNWCTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(8))) + + def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(4))) + + def test_SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=2) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=2) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + 
min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.fixed_channels, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation, conv2d_fixed_channel_problemsizes(2))) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py new file mode 100644 index 00000000..36640794 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -0,0 +1,291 @@ +# test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, 
tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=2) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=2) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 14), + cutlass.Tensor4DCoord(8, 3, 3, 14), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 23, 56, 98), + cutlass.Tensor4DCoord(128, 3, 3, 98), + cutlass.Tensor4DCoord(4, 0, 5, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align2(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=2) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=2) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + 
problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 14), + cutlass.Tensor4DCoord(8, 3, 3, 14), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 23, 56, 98), + cutlass.Tensor4DCoord(128, 3, 3, 98), + cutlass.Tensor4DCoord(4, 0, 5, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 28), + cutlass.Tensor4DCoord(8, 3, 3, 28), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 23, 56, 100), + cutlass.Tensor4DCoord(128, 3, 3, 100), + cutlass.Tensor4DCoord(4, 0, 5, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..a48cc22c --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,48 @@ +# 
test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py new file mode 100644 index 00000000..05d77052 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -0,0 +1,87 @@ +# test/unit/conv/device/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +import pycutlass +from pycutlass.conv2d_operation import * +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[4, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle2 + ) + + 
self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[2, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..4e1570c5 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,98 @@ +# test/unit/conv/device/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, 
element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=2) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=2) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.fprop, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ) + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..4c69340f --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,235 @@ +# test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def 
test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x256_64x3_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 256, 64], stages=3, + warp_count=[2, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4_128x128_32x3_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + 
+ tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.dgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.StridedDgradIdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 56, 56, 12), + cutlass.Tensor4DCoord(8, 1, 1, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 55, 55, 12), + cutlass.Tensor4DCoord(8, 1, 1, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(2, 2), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py new file mode 100644 index 00000000..370abcc5 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py @@ -0,0 +1,86 @@ +# test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase): + def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, 
opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float16, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float16, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..6e9ed6c7 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,224 @@ +# test/unit/conv/device/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, 
layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_64x256_32x4_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=8) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=8) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[64, 256, 32], stages=3, + warp_count=[1, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, 
+ math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + + def test_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_align4(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 4, 4, 12), + cutlass.Tensor4DCoord(8, 3, 3, 12), + cutlass.Tensor4DCoord(0, 0, 0, 0), + cutlass.MatrixCoord(3, 3), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py new file mode 100644 index 00000000..f92fdfb1 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py @@ -0,0 +1,87 @@ +# test/unit/conv/device/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu +import pycutlass +from pycutlass.conv2d_operation import * +from pycutlass import * +from pycutlass.test import * +import unittest + +class 
Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[2, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.analytic, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=1) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 8], stages=4, + warp_count=[2, 4, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py new file mode 100644 index 00000000..e5520715 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py @@ -0,0 +1,98 @@ +# test/unit/conv/device/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase): + def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + 
element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=4) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=4) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=8) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 16], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + self.assertTrue(test_all_conv2d(operation)) + + def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + A = TensorDescription( + element=math_inst.element_a, + layout=cutlass.TensorNHWC, + alignment=1) + B = TensorDescription( + element=math_inst.element_b, + layout=cutlass.TensorNHWC, + alignment=1) + C = TensorDescription( + element=cutlass.float32, + layout=cutlass.TensorNHWC, + alignment=4) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], stages=3, + warp_count=[2, 2, 1], + math_instruction=math_inst, + min_compute=80, max_compute=80 + ) + + operation = Conv2dOperation( + conv_kind=cutlass.conv.Operator.wgrad, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=80, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=cutlass.float32, stride_support=StrideSupport.Strided, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + problem_sizes = [ + cutlass.conv.Conv2dProblemSize( + cutlass.Tensor4DCoord(1, 8, 8, 1), + cutlass.Tensor4DCoord(1, 3, 3, 1), + cutlass.Tensor4DCoord(1, 1, 1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.MatrixCoord(1, 1), + cutlass.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + self.assertTrue(test_all_conv2d(operation, problem_sizes)) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/conv/run_all_tests.py b/tools/library/scripts/pycutlass/test/conv/run_all_tests.py new file mode 100644 index 00000000..39278be2 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/conv/run_all_tests.py @@ -0,0 +1,10 @@ +import pycutlass +import unittest +from pycutlass.memory_manager import * + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**32, 2**32) + loader = unittest.TestLoader() + tests = loader.discover('./', 'conv2d_*.py') + testRunner = unittest.runner.TextTestRunner() + testRunner.run(tests) diff --git a/tools/library/scripts/pycutlass/test/frontend/run_test.sh b/tools/library/scripts/pycutlass/test/frontend/run_test.sh new file mode 100644 index 00000000..67aa3de5 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/frontend/run_test.sh @@ -0,0 
+1 @@ +CUPY_CACHE_DIR=./ python test_frontend.py diff --git a/tools/library/scripts/pycutlass/test/frontend/test_frontend.py b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py new file mode 100644 index 00000000..6e2ee256 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/frontend/test_frontend.py @@ -0,0 +1,136 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# +## Test case for Pytorch +import pycutlass +import unittest +from pycutlass import * +import torch +import cupy as cp + + +class Test_Frontend(unittest.TestCase): + def setUp(self) -> None: + # + # define the cutlass operator + # + math_inst = MathInstruction( + [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, + cutlass.OpClass.Simt, MathOperation.multiply_add + ) + + tile_description = TileDescription( + [128, 128, 8], 4, [2, 4, 1], + math_inst, 80, 80 + ) + + A = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 + ) + + B = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 + ) + + C = TensorDescription( + cutlass.float32, cutlass.RowMajor, 1 + ) + + self.operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=cutlass.float32, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1 + ) + + pycutlass.compiler.add_module([self.operation,]) + + + def test_torch_frontend(self): + problem_size = cutlass.gemm.GemmCoord(512, 256, 128) + + tensor_A = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.k()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) + tensor_B = torch.ceil(torch.empty(size=(problem_size.k(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) + tensor_C = torch.ceil(torch.empty(size=(problem_size.m(), problem_size.n()), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) + tensor_D = torch.empty_like(tensor_C) + + + alpha = 1.0 + beta = 0.0 + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=LinearCombinationFunctorArguments(alpha, beta), + gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 + ) + + self.operation.run(arguments) + + arguments.sync() + + tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C + + self.assertTrue(torch.equal(tensor_D, tensor_D_ref)) + + def test_cupy_frontend(self): + cp.cuda.set_allocator(rmm.rmm_cupy_allocator) + + problem_size = cutlass.gemm.GemmCoord(512, 256, 128) + + tensor_A = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.m(), problem_size.k()), dtype=cp.float32)) + tensor_B = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.k(), problem_size.n()), dtype=cp.float32)) + tensor_C = cp.ceil(cp.random.uniform(low=-8.5, high=7.5, size=(problem_size.m(), problem_size.n()), dtype=cp.float32)) + tensor_D = cp.ones_like(tensor_C) + + alpha = 1.0 + beta = 1.0 + + tensor_D_ref = alpha * tensor_A @ tensor_B + beta * tensor_C + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=LinearCombinationFunctorArguments(alpha, beta), + gemm_mode=cutlass.gemm.Mode.Gemm, split_k_splices=1 + ) + + self.operation.run(arguments) + + arguments.sync() + + self.assertTrue(cp.array_equal(tensor_D, tensor_D_ref)) + + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**32, 2**32) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/__init__.py b/tools/library/scripts/pycutlass/test/gemm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py new file mode 100644 index 00000000..59bf9bb3 --- /dev/null +++ 
b/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py @@ -0,0 +1,93 @@ +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_testbed import test_all_gemm + +class GemmBF16TensorOpSm80(unittest.TestCase): + def SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[64, 128, 64], + stages=4, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.bfloat16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.bfloat16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_bf16t_bf16t_bf16t_tensor_op_f32_128x256x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.bfloat16, element_b=cutlass.bfloat16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[64, 128, 32], + stages=6, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.bfloat16, layout=cutlass.RowMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.bfloat16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.bfloat16, layout=cutlass.RowMajor, + alignment=8 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "multistage")) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**24, 2**24) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py new file mode 100644 index 00000000..284ac928 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py @@ -0,0 +1,425 @@ +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_testbed import test_all_gemm + + +class GemmF16Sm80(unittest.TestCase): + def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, 
opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.BatchedIdentitySwizzle + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + direct_store=True + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16n_f16n_f16t_tensor_op_f32_128x128x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 64], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16n_f16n_f32n_tensor_op_f32_128x256x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 256, 64], + stages=3, warp_count=[2, 4, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16n_f16n_f32t_tensor_op_f32_256x128x64_64x64x64(self): + math_inst = 
MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[256, 128, 64], + stages=3, warp_count=[4, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16n_f16t_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 64, 64], + stages=3, warp_count=[2, 1, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float16 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_GemmUniversal_f16n_f16t_f32t_tensor_op_f32_64x64x32_32x32x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float16, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[64, 64, 32], + stages=10, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float16 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, 
"universal")) + + def test_SM80_Device_Gemm_f16n_f16t_f32t_tensor_op_f32_256x128x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[256, 128, 64], + stages=3, warp_count=[4, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_test_SM80_Device_Gemm_f16t_f16n_f16t_tensor_op_f16_sliced_k_128x64x64_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 64, 64], + stages=3, warp_count=[2, 1, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_128x256x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 256, 64], + stages=3, warp_count=[2, 4, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.RowMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + 
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f16t_f16t_f32t_tensor_op_f32_128x256x64_64x64x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], + element_a=cutlass.float16, element_b=cutlass.float16, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 256, 64], + stages=3, warp_count=[2, 4, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**24, 2**24) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py new file mode 100644 index 00000000..20c39be3 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py @@ -0,0 +1,138 @@ +import pycutlass +from pycutlass import * +from pycutlass.memory_manager import get_allocated_size +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_testbed import test_all_gemm + + +class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase): + def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add_fast_bf16 + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + B = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + + def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_f32_128x128x32_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) 
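# Note on the math_operation choices in this file: instruction_shape [16, 8, 8]
# is the SM80 tensor-op MMA shape for 32-bit inputs. With plain multiply_add the
# f32 operands are rounded to TF32 before the tensor cores are used; the test
# above selects multiply_add_fast_bf16 (bf16-based emulation of f32) and the
# test below selects multiply_add_fast_f32 (the 3xTF32 fast-accurate path).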
+ + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + B = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f32n_f32n_f32t_tensor_op_fast_accurate_f32_64x64x32_32x32x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 8], + element_a=cutlass.float32, element_b=cutlass.float32, + element_accumulator=cutlass.float32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add_fast_f32 + ) + + tile_description = TileDescription( + threadblock_shape=[64, 64, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + B = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**24, 2**24) + pycutlass.compiler.load_from_cache() + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py new file mode 100644 index 00000000..04591ab2 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py @@ -0,0 +1,95 @@ +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_testbed import test_all_gemm + +class GemmF64TensorOpSm80(unittest.TestCase): + def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self): + math_inst = MathInstruction( + instruction_shape=[8, 8, 4], + element_a=cutlass.float64, element_b=cutlass.float64, + element_accumulator=cutlass.float64, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[32, 32, 16], + stages=4, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + # alignment 1 restricted for double + A = TensorDescription( + element=cutlass.float64, layout=cutlass.ColumnMajor, + alignment=1 + ) + B = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + C = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + + 
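# instruction_shape [8, 8, 4] above is the SM80 double-precision tensor-op MMA.
# f64 operands are accessed one element (8 bytes) at a time, which is why A, B,
# and C are all declared with alignment=1 ("alignment 1 restricted for double").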
element_epilogue = cutlass.float64 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + + def test_SM80_Device_Gemm_f64t_f64n_f64t_tensor_op_f64_64x64x16_32x32x16(self): + math_inst = MathInstruction( + instruction_shape=[8, 8, 4], + element_a=cutlass.float64, element_b=cutlass.float64, + element_accumulator=cutlass.float64, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[64, 64, 16], + stages=4, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + # alignment 1 restricted for double + A = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + B = TensorDescription( + element=cutlass.float64, layout=cutlass.ColumnMajor, + alignment=1 + ) + C = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + + element_epilogue = cutlass.float64 + + epilogue_functor = EpilogueFunctor.LinearCombination + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "universal")) + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**24, 2**24) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py new file mode 100644 index 00000000..6024f83c --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py @@ -0,0 +1,197 @@ +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_grouped_testbed import TestbedGrouped + + +class GemmGroupedSm80(unittest.TestCase): + def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], element_a=cutlass.float16, + element_b=cutlass.float16, element_accumulator=cutlass.float32, + opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + epilogue_functor = EpilogueFunctor.LinearCombination + swizzling_functor = cutlass.BatchedIdentitySwizzle + + for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: + operation = GemmOperationGrouped( + tile_description.minimum_compute_capability, + tile_description, A, B, C, + element_epilogue, + epilogue_functor, swizzling_functor, + precompute_mode=precompute_mode + ) + + testbed = TestbedGrouped(operation=operation) + + 
self.assertTrue(testbed.run(24)) + + def test_SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64_64x64x16_32x32x16(self): + math_inst = MathInstruction( + instruction_shape=[8, 8, 4], element_a=cutlass.float64, + element_b=cutlass.float64, element_accumulator=cutlass.float64, + opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[64, 64, 16], + stages=4, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + + B = TensorDescription( + element=cutlass.float64, layout=cutlass.RowMajor, + alignment=1 + ) + + C = TensorDescription( + element=cutlass.float64, layout=cutlass.ColumnMajor, + alignment=1 + ) + + element_epilogue = cutlass.float64 + epilogue_functor = EpilogueFunctor.LinearCombination + swizzling_functor = cutlass.BatchedIdentitySwizzle + + for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: + operation = GemmOperationGrouped( + tile_description.minimum_compute_capability, + tile_description, A, B, C, + element_epilogue, + epilogue_functor, swizzling_functor, + precompute_mode=precompute_mode + ) + + testbed = TestbedGrouped(operation=operation) + + self.assertTrue(testbed.run(24)) + + def test_SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32_128x64x8_64x32x1(self): + math_inst = MathInstruction( + instruction_shape=[1, 1, 1], element_a=cutlass.float32, + element_b=cutlass.float32, element_accumulator=cutlass.float32, + opcode_class=cutlass.OpClass.Simt, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 64, 8], + stages=4, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=1 + ) + + B = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=1 + ) + + C = TensorDescription( + element=cutlass.float32, layout=cutlass.RowMajor, + alignment=1 + ) + + element_epilogue = cutlass.float32 + epilogue_functor = EpilogueFunctor.LinearCombination + swizzling_functor = cutlass.BatchedIdentitySwizzle + + for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: + operation = GemmOperationGrouped( + tile_description.minimum_compute_capability, + tile_description, A, B, C, + element_epilogue, + epilogue_functor, swizzling_functor, + precompute_mode=precompute_mode + ) + + testbed = TestbedGrouped(operation=operation) + + self.assertTrue(testbed.run(27)) + + def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32_cache(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 16], element_a=cutlass.float16, + element_b=cutlass.float16, element_accumulator=cutlass.float32, + opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 32], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + + B = TensorDescription( + element=cutlass.float16, layout=cutlass.ColumnMajor, + alignment=8 + ) + + C = TensorDescription( + element=cutlass.float32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.float32 + epilogue_functor = 
EpilogueFunctor.LinearCombination + swizzling_functor = cutlass.BatchedIdentitySwizzle + + for precompute_mode in [SchedulerMode.Device, SchedulerMode.Host]: + operation = GemmOperationGrouped( + tile_description.minimum_compute_capability, + tile_description, A, B, C, + element_epilogue, + epilogue_functor, swizzling_functor, + precompute_mode=precompute_mode + ) + + testbed = TestbedGrouped(operation=operation) + + self.assertTrue(testbed.run(5)) + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py new file mode 100644 index 00000000..b41b78fd --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py @@ -0,0 +1,219 @@ +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +from pycutlass.test.gemm_testbed import test_all_gemm + +class GemmS8TensorOpF32Sm80(unittest.TestCase): + def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 32], + element_a=cutlass.int8, element_b=cutlass.int8, + element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add_saturate + ) + + tile_description = TileDescription( + threadblock_shape=[64, 64, 64], + stages=6, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajorInterleaved32, + alignment=16 + ) + B = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajorInterleaved32, + alignment=16 + ) + C = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajorInterleaved32, + alignment=8 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "interleaved")) + + def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_256x128x128_64x64x128(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 32], + element_a=cutlass.int8, element_b=cutlass.int8, + element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 128], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajor, + alignment=16 + ) + B = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajor, + alignment=16 + ) + C = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajor, + alignment=16 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "multistage")) + + def 
test_SM80_Device_Gemm_s8t_s8n_s8n_tensor_op_s32_128x128x128_64x64x128(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 32], + element_a=cutlass.int8, element_b=cutlass.int8, + element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 128], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajor, + alignment=16 + ) + B = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajor, + alignment=16 + ) + C = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajor, + alignment=16 + ) + + element_epilogue = cutlass.float32 + + epilogue_functor = EpilogueFunctor.FastLinearCombinationClamp + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "multistage")) + + def test_SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32_128x128x128_64x64x128(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 32], + element_a=cutlass.int8, element_b=cutlass.int8, + element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 128], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajor, + alignment=16 + ) + B = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajor, + alignment=16 + ) + C = TensorDescription( + element=cutlass.int32, layout=cutlass.ColumnMajor, + alignment=4 + ) + + element_epilogue = cutlass.int32 + + epilogue_functor = EpilogueFunctor.LinearCombinationClamp + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "multistage")) + + def test_SM80_Device_Gemm_s8t_s8n_s32t_tensor_op_s32_128x128x128_64x64x128(self): + math_inst = MathInstruction( + instruction_shape=[16, 8, 32], + element_a=cutlass.int8, element_b=cutlass.int8, + element_accumulator=cutlass.int32, opcode_class=cutlass.OpClass.TensorOp, + math_operation=MathOperation.multiply_add + ) + + tile_description = TileDescription( + threadblock_shape=[128, 128, 128], + stages=3, warp_count=[2, 2, 1], + math_instruction=math_inst, min_compute=80, max_compute=80 + ) + + A = TensorDescription( + element=cutlass.int8, layout=cutlass.RowMajor, + alignment=16 + ) + B = TensorDescription( + element=cutlass.int8, layout=cutlass.ColumnMajor, + alignment=16 + ) + C = TensorDescription( + element=cutlass.int32, layout=cutlass.RowMajor, + alignment=4 + ) + + element_epilogue = cutlass.int32 + + epilogue_functor = EpilogueFunctor.LinearCombinationClamp + + swizzling_functor = cutlass.IdentitySwizzle1 + + operation = GemmOperationUniversal( + arch=80, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, 
swizzling_functor=swizzling_functor + ) + + self.assertTrue(test_all_gemm(operation, "multistage")) + + + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**24, 2**24) + unittest.main() diff --git a/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py b/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py new file mode 100644 index 00000000..8a874446 --- /dev/null +++ b/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py @@ -0,0 +1,9 @@ +import pycutlass +import unittest + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**26, 2**26) + loader = unittest.TestLoader() + tests = loader.discover('./', 'gemm_*.py') + testRunner = unittest.runner.TextTestRunner() + testRunner.run(tests) diff --git a/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt b/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt new file mode 100644 index 00000000..c5e51d9f --- /dev/null +++ b/tools/library/scripts/pycutlass/test/unit/cached_results_SM80_2080.txt @@ -0,0 +1,350 @@ +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1767700736 2104699940 3506659864 557648934 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1539314507 3971227455 1976927351 1642148785 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 276489656 653235219 3147305346 880610205 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 272457724 2178229139 2786201726 4170295839 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 242235041 2149454506 784935854 682531065 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 3478189705 1667216236 1437761176 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 379326961 1780379994 3740415776 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 215903418 924848818 3533854396 2683779476 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 359232443 2147867990 1653277018 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2870331951 3784314846 2644315999 4224154526 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3787448414 3562991793 535073859 2563373454 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 426169840 2464808416 864648234 461884698 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2564934525 3910792915 3577331017 827498183 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 28479234 867695528 1947311971 83328334 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4192922822 4244595864 2296602326 2349214706 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 274678245 3464152269 1682550229 3446204619 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3993280136 828543035 1319748516 956044554 +conv2d 
fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 832003025 3799813757 4030292245 457791957 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1444316594 4129865888 93616503 412257611 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 36703874 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2931873718 1841508064 1497852219 1842147148 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1612565294 109894479 1782187316 3370789453 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 841569299 1010785577 1158956167 3261208135 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1893352157 48149942 3544807462 446577726 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 3585320147 2150950452 1625817025 3964129474 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3016005301 4142905842 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3337296764 4183699161 3654176452 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 3852963969 864006170 920352568 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 2750240096 2120184232 2600672872 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3224082300 2084034673 3588056946 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 3033073939 304048758 1882633089 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 610026473 447427404 2639856195 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2818680871 58428273 3332443900 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 1891702153 103393067 2558647731 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 162127134 3567670201 3173514764 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 162127134 3567670201 363897018 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 1350938697 1696306119 1005311005 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3884703009 3552725366 1975514757 1210310496 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2056905385 447674669 724481645 1457430910 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 1227883689 3401425854 3897766524 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 
fnhwc_fnhwc_fnhwc_f_f 1972540475 3749787834 3350064812 1136116240 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3414629540 820341033 770836461 2451581199 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4100326666 2581696511 1088458082 1521190911 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3662895757 2885454895 935600441 2615245898 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 3831334389 3506139121 814982501 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2154102133 737968461 1291834254 2665225480 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 1809195644 1765637461 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 4120764770 3573498719 3379808294 483095299 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1972540475 4194153035 2863868771 1639389008 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2624318208 157618421 1779474147 814087242 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2044596379 2300180628 423968553 3890279569 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2609259399 1848932917 522753581 1926508271 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2948772873 3663040534 4014266327 1288646188 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3271403719 1585195072 1487505772 3253374264 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 1419588777 451194147 3578359696 3659768981 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 763924990 2780826684 2883769406 148530958 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2578426561 3849874822 102765469 1305171059 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1516344656 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 740110603 1995451256 2632815435 1586331550 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 2462511240 2274021368 1188866747 3178890497 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 752289976 1226457131 4187777346 1400559240 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 fnhwc_fnhwc_fnhwc_f_f 3723912751 1585959358 3731079159 1498901684 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 fnhwc_fnhwc_fnhwc_f_f 2027599472 2758666204 3287095476 4291916486 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3393706648 3519979618 1149261202 799742106 +conv2d 
fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3409586999 409840186 1724648597 2642018980 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1815685330 1398622058 2431638856 1016967269 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2555706782 3271563943 1020153035 299097281 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4173830187 736684125 472021975 2064613035 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 2751224679 2250540122 3725638844 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 1583610315 3287895411 2394340435 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3010335403 2356047354 7055632 915702611 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2748205217 2539405983 1217377670 2011175578 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2748205217 2114448427 249997769 2711364520 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1528321643 1532777511 3597171412 296622236 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1326617037 3415095747 847196866 1481554158 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1122706355 2841974626 2791878604 632900093 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1728385278 2462678309 3066040807 1334515660 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2175275779 1117731224 857614711 2096711962 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4140401170 3710340185 1683575469 317397427 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3552249008 2918315307 2290683130 536859016 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2869959072 2516947012 3328285094 2393284712 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1349264322 1823945068 400087667 2893025864 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3321662203 426084311 4233055093 4078572279 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3321662203 426084311 4233055093 3044377475 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 803041205 2521863610 3206942690 127091020 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4083508736 37801570 240515127 2234797539 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2207374588 535059558 2268619394 1489214085 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 dnhwc_dnhwc_dnhwc_d_d 3614026280 
1721563676 2979825951 1104908081 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 2226238626 2053372396 2462697514 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 235646718 1374133172 3696289981 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2536722089 184705847 3148323124 84213385 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2264868815 1724845245 3498302256 4094034457 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 233390337 1801952602 3532884734 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3048346885 2306163504 642074123 4083120683 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2798030672 683783039 3025345160 1890891136 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1731071506 1844675436 2292509333 4006304179 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 132147677 604503886 143348844 3037223953 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 1678940393 3405733837 1820114523 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 1678940393 3405733837 467254076 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1794301352 2320042028 2134048179 508141072 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 561590023 3382154048 4154621995 517057927 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 593915463 2360210889 2685491481 2265099675 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 2226238626 1155815529 558646991 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2536722089 1876429398 4216128545 1754596046 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 348523586 2609019785 3938405680 2601133907 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1984146316 1475870285 1157657800 1143965395 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2971058593 1478256319 503014742 3930504182 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 1214508920 1537003531 3830217225 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2031518387 2695641559 933408074 4026827730 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 517276344 1158854831 3123629043 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 810746104 517276344 1448394173 1864626308 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 
dnhwc_dnhwc_dnhwc_d_d 2536722089 711164468 2465036841 2993377049 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2264868815 3003481795 333430991 3094857755 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1621735632 1126010692 3313703859 637497110 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1130094757 2605103293 2477101661 1276123281 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 4286533436 1302900889 2613245986 2523724148 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 3048346885 923365529 1681226722 417509256 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2798030672 3441819646 1293178065 188472807 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1731071506 1117530547 2706270359 502156742 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 132147677 2029225588 3851064913 3164530726 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 2337137106 3312954197 2466682688 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1269799445 2337137106 3312954197 2684544683 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 1794301352 72938921 2354994612 1463501392 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 252570564 2903451081 3619280116 1448586411 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 dnhwc_dnhwc_dnhwc_d_d 2037991187 1665743881 241585763 103256264 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 dnhwc_dnhwc_dnhwc_d_d 2653975581 3337638999 1440125233 2448165745 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 991402150 1393431534 1148212814 1350914659 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4208297221 4283492776 419570292 1210341563 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4178596783 3828059710 2735749436 2671012171 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 924522595 563724475 3750778972 4152580670 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1021044158 1686067905 3765040166 4102272733 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 2674994719 635224486 2759329777 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 4201252830 2920298728 304256151 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3335547 70289262 646435722 4137562540 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 1288095320 2132879813 656196754 +conv2d 
fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1317457392 2202157489 2326567490 2475188414 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2476454437 1857118302 4164386062 239840568 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2767650699 3514840131 590439733 3879821123 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3896287283 3112762669 2515107934 2106635937 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1903067870 1021832870 3003938078 2751931686 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3489785028 2466126497 1374078692 2737628040 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2051350923 263676708 3639860119 1370886256 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 719099834 1474713672 204857540 2768940347 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3441724486 3162593831 421721594 3097845598 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2034354027 1249407570 2567025479 1441082595 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 2369653089 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 941893937 3608468045 635631428 1218705038 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 172579142 319546523 718795680 1453661415 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2823351660 1326352711 1110204809 1155441703 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3238446487 2572503545 686287700 1559476701 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_f_f 2149247508 1775375365 3317647029 2497607448 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 3464637181 1623218578 436154205 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 4110991321 1479940693 3253144559 3883419107 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 832653836 1871463331 2425320272 74566211 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3484040069 664160900 3610888033 22347127 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1513864544 1924855848 1382111427 2541177413 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 868180534 1764715518 3070473696 2392864704 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3437976747 666906244 3401957738 2050602745 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 
hnhwc_hnhwc_fnhwc_f_f 4195072693 1575210381 781892324 2848949054 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3457330201 2316839359 1539389419 4293781748 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 2693098375 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 754609939 2469024119 2885305868 1969608051 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 1690216859 554790212 2885143346 780489333 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_fnhwc_f_f 3184127693 835105643 3337423971 3866137775 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 927718585 4106152802 720400339 3989318043 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 3464637181 4051957661 126285749 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 3723472741 2044236350 2463899842 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 2075083065 2042513140 3691286135 322550345 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4005590448 1116254439 2328237343 1918824440 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 181075276 1743485155 3526891198 1979405632 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 386662952 4057300775 1456746562 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 856324887 3954249564 2340393915 4127188930 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1300426008 2921497047 4145791960 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4110991321 1300426008 4080981223 3076991942 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 832653836 447261065 3823545045 392205236 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3484040069 2966693627 3900095420 919511892 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1513864544 1759979610 4272621682 1029257940 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1906605830 2980501720 978889789 3136018973 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 805717279 3502822733 1810065278 1387739380 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 868180534 3289288595 209477462 4142168174 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3437976747 3391080565 97275649 4063718293 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 4195072693 1669352457 2182133559 2494741804 +conv2d 
wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3457330201 1126870455 319272291 3811977088 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 1723074453 1660326213 3902884425 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 754609939 1723074453 1660326213 423159249 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 1690216859 2413490039 223529410 3303697952 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 3168796339 1601750164 1428743330 403295189 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_f_f 261954979 1300976652 2749562370 3058142403 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_f_f 3747142491 1747587481 3143977827 835130482 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1736512560 49406874 846358010 3314905564 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1848484956 1432417472 1903569827 3750799351 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4236427320 3696009469 69852620 201921851 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 109006944 450017448 1793784844 903209915 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 813367872 2397796503 1928191746 3210229460 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1307184141 46021356 1674017987 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 1212511562 3331767121 2446286369 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1348284291 2013675943 1681111033 1469213228 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 500298386 3218034344 4159283207 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1703349794 1123534155 145385311 4273847179 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3862659311 349459322 1503631520 1404971956 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1623686755 961217371 552550209 3980749384 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3554927580 1131648083 4149599295 3119557776 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1767639287 3350675774 128324027 1059816532 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3986143536 17411088 40173029 1694092310 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1157793540 3513299281 48848814 1435528367 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 988962069 4292634763 388976034 
2674929544 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4202383208 3529769234 1046186503 3368902675 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 856448884 3057259762 2063087558 1995545427 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 400986166 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2281940872 144496548 2455451862 1082696406 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2702905851 1992889713 731289041 608504198 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2742293143 4197915274 606840 3671124731 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 149434841 2288560511 2994968424 2881838300 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 2226824643 327135318 3718671210 2121176659 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 4172720592 446082987 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1101653138 3727072529 875733988 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 3906526127 655926291 939844058 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 2031878085 1709408312 1277173429 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 22652410 1700696921 2175632852 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 436588210 470857851 284463232 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 59350507 969037229 1510558485 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 856797938 2030818524 4231831552 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 2885833872 2829967135 3441569557 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 4148824382 2827420298 378131261 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 4148824382 2827420298 2955292920 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 1474248671 1302526250 4182204885 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1569788048 162506176 819639712 763595635 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 289918791 1266976707 942688231 3457364823 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 1027662440 2005082293 2235558527 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 
hnhwc_hnhwc_hnhwc_h_h 2754803027 3380032042 1370040310 1348846927 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 671982235 1423304149 2107662762 1234913781 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 798317794 1709026638 2421185623 3308071321 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1721270411 2519327328 2541413264 3185574975 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2070174510 1364436192 3531942595 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2128738105 2056902987 3079166829 2329433528 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3227877956 645422556 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 672413387 3857917762 3817218800 985231315 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2754803027 1398036015 3630062764 2492522537 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2784049299 643733019 3649549642 2637869234 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2756413475 2332160299 302086821 3303132343 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1931093565 2458714707 2919710256 2311575036 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2472246681 2260022344 500095455 2760458995 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1530672622 3635363851 2402907878 4131497953 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 1500864134 2536338700 2459524764 2504484273 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 3344871528 2667385029 2714805835 3487838445 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 966721255 1547169349 3198573835 302049294 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 1317923157 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2643693957 2440004820 1576818970 3186679687 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 4028893260 4220759192 2236533218 3731336532 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 2956871200 1591352238 1756650151 1262787222 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 hnhwc_hnhwc_hnhwc_h_h 365467186 892422645 1334708242 1372556938 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 hnhwc_hnhwc_hnhwc_h_h 3347784734 150035460 2897171548 3701081496 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 945660191 3750377696 2496492611 3515056508 +conv2d 
fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2806300501 2591577756 3148637036 3845512743 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2322444122 3525997046 281106520 3456307300 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 327345109 1137297282 1938163814 2551101563 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 797067973 481331945 350851834 2477733239 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 2044204046 1034822169 3340281844 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 4174274001 1597212204 1881272946 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1316460560 1535088984 3001492060 2308505016 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3190527989 3733991924 4211138051 3710311115 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3190527989 3430768821 1043108884 4185640072 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 943531303 1948306075 3877008798 2803592376 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3262141476 4125717435 2946529611 2221512094 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1599291337 3982786366 1581171257 1188352423 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2237070215 3046262465 1926804094 1435916873 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 721666814 2012769306 1712378956 1388990183 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1596349869 3775131163 355203300 1126174452 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1380587417 1208642645 2886387159 3113955983 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1332573203 1417735573 1422796372 3309229181 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2714027800 2106992819 1196036582 2095126659 +conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1105097447 1992731268 2198911423 3378137735 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1105097447 1992731268 2198911423 3868431311 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2552471160 2218470296 2332616929 923645661 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2231354584 4035702005 3839068434 8981294 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 4019719318 3985307916 
3604065639 277096636 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 bf16nhwc_bf16nhwc_fnhwc_f_f 258381429 3482776077 2663631601 593179089 +conv2d dgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 188810648 1623218578 2585892217 +conv2d dgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 691990354 3253144559 2988350639 +conv2d dgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 1670375523 2425320272 2553108650 +conv2d dgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1049321188 1865889553 3610888033 1459693945 +conv2d dgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 3236781482 1382111427 1986396315 +conv2d dgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 463742721 2524037630 3070473696 210045128 +conv2d dgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 738614177 4071452982 3401957738 2920893800 +conv2d dgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2479111539 2662555669 781892324 2338234282 +conv2d dgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2089076160 260434096 1539389419 1219120658 +conv2d dgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 3344412669 2885305868 1926445693 +conv2d dgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 3344412669 2885305868 1478058549 +conv2d dgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3945616248 4118489020 2885143346 1545684873 +conv2d dgrad_1x56x56x8_28x28_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 295760528 1685244361 3337423971 772814550 +conv2d wgrad_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 623727338 942771643 2634710231 3063349371 +conv2d wgrad_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 188810648 2709881923 3532383400 +conv2d wgrad_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 3762161398 3733128758 3693097785 +conv2d wgrad_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 139944998 3812563855 253288229 1359907535 +conv2d wgrad_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 492562992 3677108443 525487530 445191233 +conv2d wgrad_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 594197095 3773864559 91136873 4170763393 +conv2d wgrad_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 1025574686 1127709182 677727764 +conv2d wgrad_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1901075489 3296829308 2591894666 2932517926 +conv2d wgrad_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 
4223561525 1263618595 50680160 +conv2d wgrad_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1613915345 4223561525 1756414462 3209752057 +conv2d wgrad_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2788041828 1023542180 121940906 624551470 +conv2d wgrad_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1049321188 296097075 1423016429 1058165639 +conv2d wgrad_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3820648800 4160685370 2761559427 1788182893 +conv2d wgrad_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1859384988 222880684 1650970502 1632078530 +conv2d wgrad_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 1704522433 2403392926 3985958544 1432584676 +conv2d wgrad_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 463742721 3455033786 385631111 1683348880 +conv2d wgrad_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 738614177 3199562330 1513955316 2131256035 +conv2d wgrad_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2479111539 2702777753 2608107448 4014212857 +conv2d wgrad_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2089076160 4042009058 106232038 1140762595 +conv2d wgrad_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 2260768172 1186911503 3194129408 +conv2d wgrad_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 14838294 2260768172 1186911503 1312312812 +conv2d wgrad_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 3945616248 2287161276 36034283 4262860382 +conv2d wgrad_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 2906914535 476297538 14375779 1340176713 +conv2d wgrad_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 bf16nhwc_bf16nhwc_fnhwc_f_f 4292101959 3378414564 4259930640 1392755176 +conv2d wgrad_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 bf16nhwc_bf16nhwc_fnhwc_f_f 3529371817 368260304 4137156526 122558013 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 2948718568 2631391783 3260825675 4278587299 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 1635109696 2835574424 4179385325 2803281440 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 3344954627 1649157278 2032056735 1176638626 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 61750237 3452849177 1697665310 3475459781 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 1394759191 1571308277 898534533 4125341936 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 3402206912 2433594404 1575577431 4106154211 
+conv2d fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 98638790 2735493952 346473870 1911666301 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 98638790 2735493952 346473870 2124440208 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 2934485636 3286257323 541566528 1113783492 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nc32hw32_s8c32rsk32_s8nc32hw32_i_f 164942943 4259285988 1250700182 508419908 +conv2d fprop_1x1x1x64_3x3_8x1x1_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3805460372 2607401558 3465030781 210641751 +conv2d fprop_1x1x8x64_3x8_8x1x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 4200926784 1001915027 387475271 3360115596 +conv2d fprop_1x7x8x64_7x8_8x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 331078659 469730619 2547196469 1620698703 +conv2d fprop_1x7x9x64_6x8_8x4x4_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 431968022 1614654085 903827412 1349891842 +conv2d fprop_2x7x9x64_5x7_8x5x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3674369485 1055554271 3217013807 1356703347 +conv2d fprop_3x7x9x64_4x7_8x6x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 3227824772 365527403 2720889763 +conv2d fprop_3x7x9x64_4x6_8x6x6_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 2150996976 2899308770 2371758816 +conv2d fprop_3x7x9x64_3x5_8x7x7_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2609462247 2124373651 2711906981 3194739760 +conv2d fprop_1x11x7x64_6x4_8x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1070162100 2750964634 3090791018 3481982191 +conv2d fprop_1x11x7x64_6x4_8x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1070162100 1563941622 767747438 3163252390 +conv2d fprop_1x13x11x64_8x7_8x1x1_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 884815233 3576251756 3216742798 3534462723 +conv2d fprop_1x17x19x64_9x10_16x2x2_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3230717758 3192193994 1161445944 371179683 +conv2d fprop_1x23x5x64_12x3_16x3x3_pad_h1w1_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2450454245 2905280248 910194866 839083662 +conv2d fprop_1x23x21x128_23x21_224x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2948718568 2631391783 638794727 4292051282 +conv2d fprop_1x16x24x128_16x24_96x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1635109696 2835574424 1855687620 130932480 +conv2d fprop_1x55x51x256_28x26_512x1x1_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3344954627 1649157278 4191418350 958044197 +conv2d fprop_1x27x23x256_9x7_512x3x3_pad_h0w0_stride_h3w3_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 61750237 3452849177 3260472389 771128506 +conv2d fprop_1x27x31x256_12x11_512x3x3_pad_h5w7_stride_h3w4_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 1394759191 1571308277 4279538191 956191103 +conv2d fprop_1x27x35x256_15x9_512x7x5_pad_h11w7_stride_h3w5_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 3402206912 2433594404 2021112123 2983097553 +conv2d 
fprop_1x27x27x256_27x14_512x3x3_pad_h1w1_stride_h1w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 98638790 2735493952 3178839372 568554158 +conv2d fprop_1x27x27x256_14x27_512x3x3_pad_h1w1_stride_h2w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 98638790 2735493952 3178839372 18194802 +conv2d fprop_3x28x28x256_14x14_256x2x2_pad_h0w0_stride_h2w2_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2934485636 3286257323 2559221535 2310182528 +conv2d fprop_4x4x5x128_3x3_256x3x6_pad_h0w0_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 164942943 4259285988 984016853 888753301 +conv2d fprop_4x2x3x256_1x1_328x3x5_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha1_beta0 s8nhwc_s8nhwc_inhwc_i_i 2823094147 1681845497 4242738907 3244428635 +conv2d fprop_1x17x11x288_17x11_160x3x3_pad_h1w1_stride_h1w1_dil_h1w1_corr_alpha2_beta2 s8nhwc_s8nhwc_inhwc_i_i 4060010502 2881035321 3927119619 3311661122 diff --git a/tools/library/scripts/pycutlass/test/unit/test_sm80.py b/tools/library/scripts/pycutlass/test/unit/test_sm80.py new file mode 100644 index 00000000..bedb3a3a --- /dev/null +++ b/tools/library/scripts/pycutlass/test/unit/test_sm80.py @@ -0,0 +1,440 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +## Test case generator for SM80 + +import pycutlass +from pycutlass import * +from pycutlass.test import * +import unittest + +# +# Create GEMM operation +# + +def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False, + epilogue_functor = EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1, **kwargs): + """ + Test GEMM Operation based on configuration + """ + + if "data_type" in kwargs.keys(): + data_type = kwargs["data_type"] + else: + if mixed or math_inst.element_a == cutlass.bfloat16: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator + ] + else: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator + ] + + tile_description = TileDescription( + tiling[0], tiling[1], tiling[2], + math_inst, arch, arch + ) + + A = TensorDescription( + data_type[0], layout[0], alignment[0] + ) + + B = TensorDescription( + data_type[1], layout[1], alignment[1] + ) + + C = TensorDescription( + data_type[2], layout[2], alignment[2] + ) + + element_epilogue = data_type[3] + + if gemm_kind == GemmKind.Universal: + operation = GemmOperationUniversal( + arch=arch, tile_description=tile_description, + A=A, B=B, C=C, element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor + ) + if A.layout in [cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32]: + return test_all_gemm(operation, "interleaved") + else: + return test_all_gemm(operation, "universal") + + elif gemm_kind == GemmKind.Grouped: + operation = GemmOperationGrouped( + arch, tile_description, A, B, C, + element_epilogue, epilogue_functor, swizzling_functor, + precompute_mode=kwargs["precompute_mode"] + ) + testbed = TestbedGrouped(operation=operation) + return testbed.run(24) + else: + raise NotImplementedError("the gemm kind is not implemented") + + +def TestConv2dOperator(math_inst, alignment, tiling, arch, + stride_supports=[StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided], + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=cutlass.IdentitySwizzle1, interleaved=False, **kwargs): + """ + Test Conv2d Operation based on configurations + """ + + mixeds = [False, True, False] + conv_kinds = [cutlass.conv.Operator.fprop, cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad] + + results = [] + + default_swizzling_functor = swizzling_functor + + if "layout" in kwargs.keys(): + layout = kwargs["layout"] + else: + layout = (cutlass.TensorNHWC, cutlass.TensorNHWC, cutlass.TensorNHWC) + + for mixed, conv_kind, stride_support in zip(mixeds, conv_kinds, stride_supports): + + if "data_type" in kwargs.keys(): + data_type = kwargs["data_type"] + else: + if mixed or math_inst.element_a == cutlass.bfloat16: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator + ] + else: + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator + ] + # skip Int8 Conv Backward + if data_type[0] == cutlass.int8 and conv_kind in [cutlass.conv.Operator.dgrad, cutlass.conv.Operator.wgrad]: + continue + + A = TensorDescription( + element=data_type[0], + layout=layout[0], + alignment=alignment[0]) + B = TensorDescription( + element=data_type[1], + 
layout=layout[1], + alignment=alignment[1]) + C = TensorDescription( + element=data_type[2], + layout=layout[2], + alignment=alignment[2]) + + tile_description = TileDescription( + threadblock_shape=tiling[0], stages=tiling[1], + warp_count=tiling[2], + math_instruction=math_inst, + min_compute=arch, max_compute=arch + ) + + if conv_kind == cutlass.conv.Operator.dgrad and stride_support == StrideSupport.Strided: + swizzling_functor = cutlass.StridedDgradIdentitySwizzle1 + else: + swizzling_functor = default_swizzling_functor + + operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=cutlass.conv.IteratorAlgorithm.optimized, + arch=arch, tile_description=tile_description, A=A, B=B, C=C, + element_epilogue=data_type[3], stride_support=stride_support, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor + ) + + results.append(test_all_conv2d(operation, interleaved=interleaved)) + + return results + + + +class Test_SM80(unittest.TestCase): + def test_SM80_TensorOp_16816(self): + math_instructions = [ + MathInstruction( + [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add + ), + MathInstruction( + [16, 8, 16], cutlass.float16, cutlass.float16, cutlass.float16, + cutlass.OpClass.TensorOp, MathOperation.multiply_add + ), + MathInstruction( + [16, 8, 16], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add + ) + ] + + layouts = [ + (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor), + (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.RowMajor), + (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.RowMajor) + ] + + alignments = [ + (8, 8, 8), (4, 8, 8), (8, 4, 8) + ] + + tilings = [ + ([256, 128, 32], 3, [4, 2, 1]), + ([64, 256, 32], 4, [1, 4, 1]), + ([128, 64, 64], 3, [2, 2, 1]) + ] + + for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings): + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False)) + self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host)) + stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports) + for res in results: + self.assertTrue(res) + + def test_SM80_TensorOp_1688(self): + # tf32 is not supported by most of python environment. 
Skip the test + self.assertTrue(True) + + def test_SM80_TensorOp_1688_fast_math(self): + math_instructions = [ + MathInstruction( + [16, 8, 8], cutlass.tfloat32, cutlass.tfloat32, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add + ), + MathInstruction( + [16, 8, 8], cutlass.float16, cutlass.float16, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f16 + ), + MathInstruction( + [16, 8, 8], cutlass.bfloat16, cutlass.bfloat16, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_bf16 + ), + MathInstruction( + [16, 8, 8], cutlass.float32, cutlass.float32, cutlass.float32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add_fast_f32 + ) + ] + + layouts = [ + (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor), + (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor), + (cutlass.ColumnMajor, cutlass.RowMajor, cutlass.ColumnMajor), + (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.RowMajor) + ] + alignments = [ + (4, 4, 4), (4, 2, 4), (2, 4, 4), (2, 2, 4) + ] + tilings = [ + ([128, 256, 16], 3, [4, 2, 1]), + ([64, 256, 16], 4, [1, 4, 1]), + ([128, 64, 32], 3, [2, 2, 1]), + ([256, 64, 32], 3, [4, 2, 1]) + ] + data_type = [ + cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32 + ] + for math_inst, layout, alignment, tiling in zip(math_instructions, layouts, alignments, tilings): + self.assertTrue( + TestGemmOperator( + GemmKind.Universal, math_inst, layout, + alignment, tiling, 80, False, data_type=data_type)) + self.assertTrue( + TestGemmOperator( + GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, + True, precompute_mode=SchedulerMode.Device, data_type=data_type)) + stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) + for res in results: + self.assertTrue(res) + + def test_SM80_TensorOp_884(self): + math_inst = MathInstruction( + [8, 8, 4], cutlass.float64, cutlass.float64, cutlass.float64, + cutlass.OpClass.TensorOp, MathOperation.multiply_add + ) + layout = (cutlass.ColumnMajor, cutlass.ColumnMajor, cutlass.ColumnMajor) + alignment = (1, 1, 1) + + tiling = ([64, 256, 16], 3, [2, 4, 1]) + data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64] + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) + self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type)) + stride_supports = [StrideSupport.Unity, StrideSupport.Strided, StrideSupport.Unity] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) + for res in results: + self.assertTrue(res) + + def test_SM80_TensorOp_16832_TN(self): + math_inst = MathInstruction( + [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate + ) + layout = (cutlass.RowMajor, cutlass.ColumnMajor, cutlass.ColumnMajor) + alignment = (16, 16, 4) + alignment_mixed = (16, 16, 16) + tiling = ([128, 256, 64], 3, [2, 4, 1]) + + data_type = [cutlass.int8, cutlass.int8, cutlass.int32, cutlass.int32] + data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32] + + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, 
data_type=data_type)) + self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment_mixed, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type_mixed)) + stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) + for res in results: + self.assertTrue(res) + + def test_SM80_Simt_f32(self): + math_inst = MathInstruction( + [1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32, + cutlass.OpClass.Simt, MathOperation.multiply_add + ) + layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.RowMajor) + alignment = (1, 1, 1) + + tiling = ([128, 256, 8], 4, [2, 4, 1]) + data_type = [cutlass.float32, cutlass.float32, cutlass.float32, cutlass.float32] + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) + self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Host, data_type=data_type)) + stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) + for res in results: + self.assertTrue(res) + + def test_SM80_Simt_f64(self): + math_inst = MathInstruction( + [1, 1, 1], cutlass.float64, cutlass.float64, cutlass.float64, + cutlass.OpClass.Simt, MathOperation.multiply_add + ) + layout = (cutlass.RowMajor, cutlass.RowMajor, cutlass.ColumnMajor) + alignment = (1, 1, 1) + + tiling = ([64, 128, 8], 5, [2, 2, 1]) + data_type = [cutlass.float64, cutlass.float64, cutlass.float64, cutlass.float64] + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment, tiling, 80, False, data_type=data_type)) + self.assertTrue(TestGemmOperator(GemmKind.Grouped, math_inst, layout, alignment, tiling, 80, True, precompute_mode=SchedulerMode.Device, data_type=data_type)) + stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] + results = TestConv2dOperator(math_inst, alignment, tiling, 80, stride_supports=stride_supports, data_type=data_type) + for res in results: + self.assertTrue(res) + + def test_SM80_TensorOp_16832_Interleaved(self): + math_inst = MathInstruction( + [16, 8, 32], cutlass.int8, cutlass.int8, cutlass.int32, + cutlass.OpClass.TensorOp, MathOperation.multiply_add_saturate + ) + + layout = (cutlass.ColumnMajorInterleaved32, cutlass.RowMajorInterleaved32, cutlass.ColumnMajorInterleaved32) + alignment_mixed = (16, 16, 8) + tiling = ([256, 64, 64], 4, [4, 1, 1]) + data_type_mixed = [cutlass.int8, cutlass.int8, cutlass.int8, cutlass.float32] + + self.assertTrue(TestGemmOperator(GemmKind.Universal, math_inst, layout, alignment_mixed, tiling, 80, False, data_type=data_type_mixed, epilogue_functor=EpilogueFunctor.FastLinearCombinationClamp)) + stride_supports = [StrideSupport.Strided, StrideSupport.Strided, StrideSupport.Strided] + layout = [cutlass.TensorNC32HW32, cutlass.TensorC32RSK32, cutlass.TensorNC32HW32] + results = TestConv2dOperator(math_inst, alignment_mixed, tiling, 80, stride_supports=stride_supports, data_type=data_type_mixed, layout=layout, interleaved=True) + for res in results: + self.assertTrue(res) + + def SM80_SparseTensorOp_16832(self): + pass + def test_SM80_PlanarComplexTensorOp_16816(self): + pass + def test_SM80_SparseTensorOp_16816_fast_math(self): + 
pass + def test_SM80_TensorOp_1688_complex(self): + pass + def test_SM80_TensorOp_1688_fast_fp32_math_complex(self): + pass + def test_SM80_TensorOp_1688_rank_k(self): + pass + def test_SM80_TensorOp_1688_rank_k_complex(self): + pass + def test_SM80_TensorOp_1688_trmm(self): + pass + def test_SM80_TensorOp_1688_trmm_complex(self): + pass + def test_SM80_TensorOp_1688_symm(self): + pass + def test_SM80_TensorOp_1688_symm_complex(self): + pass + def test_SM80_TensorOp_884_complex(self): + pass + def test_SM80_TensorOp_884_complex_gaussian(self): + pass + def test_SM80_TensorOp_884_rank_k(self): + pass + def test_SM80_TensorOp_884_rank_k_complex(self): + pass + def test_SM80_TensorOp_884_rank_k_complex_gaussian(self): + pass + def test_SM80_TensorOp_884_trmm(self): + pass + def test_SM80_TensorOp_884_trmm_complex(self): + pass + def test_SM80_TensorOp_884_trmm_complex_gaussian(self): + pass + def test_SM80_TensorOp_884_symm(self): + pass + def test_SM80_TensorOp_884_symm_complex(self): + pass + def test_SM80_TensorOp_884_symm_complex_gaussian(self): + pass + def test_SM80_SparseTensorOp_16864_TN(self): + pass + def test_SM80_TensorOp_16864_TN(self): + pass + def test_SM80_SparseTensorOp_168128_TN(self): + pass + def test_SM80_TensorOp_16864_Interleaved(self): + pass + def test_SM80_TensorOp_168256(self): + pass + def test_SM80_Simt_complex(self): + pass + + +if __name__ == '__main__': + pycutlass.get_memory_pool(2**20, 2**34) + pycutlass.compiler.nvcc() + unittest.main() diff --git a/tools/util/include/cutlass/util/device_layernorm.h b/tools/util/include/cutlass/util/device_layernorm.h new file mode 100644 index 00000000..c7af287b --- /dev/null +++ b/tools/util/include/cutlass/util/device_layernorm.h @@ -0,0 +1,644 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. + * \tparam T: data type + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + T local_val[ITEM_PER_THREAD]; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + const T zero = T(0.0f); + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + local_val[i] = index < n ? input[index] : zero; + local_sums[0] += static_cast(local_val[i]); + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n){ + const float tmp = static_cast(local_val[i]) - s_mean; + local_sums[0] += tmp * tmp; + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n) { + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T2 local_val[ITEM_PER_THREAD]; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + const T2 zero = {T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_2 ? 
input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const float2 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, + const T4* input, + const T4* gamma, + const T4* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T4 local_val[ITEM_PER_THREAD]; + const int n_4 = n / 4; + int offset = m_idx * n_4; + input += offset; + output += offset; + + const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_4 ? 
input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + + static_cast(local_val[i].z) + static_cast(local_val[i].w); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const float4 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean, + static_cast(local_val[i].z) - s_mean, + static_cast(local_val[i].w) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const T4 gamma_val = gamma[index]; + const T4 beta_val = beta[index]; + T4 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); + tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_val = local_val - s_mean; + local_sums[0] += local_val * local_val; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + const T local_val = input[index]; + output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_e2(T2* output, + const T2* input, + const 
T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const float2 tmp = {static_cast(local_val.x) - s_mean, + static_cast(local_val.y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } +} + +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + dim3 grid(m); + dim3 block((n + 31)/32*32); + if (block.x > 1024){ + block.x = 1024; + } + // TODO : There should be better configs for different cases, we only use several samples to show how to use here + // TODO : using registers to store values locally can reduce the ldgs from global memory and speedup the kernels. 
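+  // Dispatch summary for the branches below (descriptive comments only):
+  //  * block.x is always rounded up to a multiple of the warp size (32) and capped at 1024 threads.
+  //  * n % 4 == 0 and 128 <= n <= 4096: float4/half4 kernel that keeps the whole row in registers,
+  //    e.g. n = 1000 gives block.x = (1000/4 + 31)/32*32 = 256 threads.
+  //  * n % 2 == 0: float2/half2 kernels whose per-thread item count grows with n up to n <= 32768;
+  //    for larger even n the plain two-pass kernel re-reads the row from global memory (block.x <= 512).
+  //  * odd n: scalar kernels with the same tiering on n, falling back to the plain two-pass kernel.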
+ if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { + block.x = (n/4 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (float4*)output, + (const float4*)input, + (const float4*)gamma, + (const float4*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (half4*)output, + (const half4*)input, + (const half4*)gamma, + (const half4*)beta, + m, + n); + } + } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) + else if (n % 2 == 0) { + if (n / 2 <= 1024) { + block.x = (n/2 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } //if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n / 2 <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/ 16 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 32768) + else { + if (block.x > 512) + block.x = 512; + if (std::is_same::value) { + layernorm_twoPassAlgo_e2<<>>( + (float2 *)output, + (const float2 *)input, + (const float2 *)gamma, + (const float2 *)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_e2<<>>( + (half2 *)output, + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + m, + n); + } + } + } // if (n % 2 == 0) + else { + if (n <= 1024) { + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/16 + 32)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 32768) + else{ + if (block.x > 512) { + block.x = 512; + } + layernorm_twoPassAlgo_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } 
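+  // Usage sketch (illustrative only, not part of this header): with row-major device tensors
+  // managed by cutlass::HostTensor, a caller could invoke this helper roughly as
+  //
+  //   cutlass::HostTensor<float, cutlass::layout::RowMajor> out({m, n}), in({m, n});
+  //   cutlass::HostTensor<float, cutlass::layout::RowMajor> gamma({1, n}), beta({1, n});
+  //   cutlass::layernorm({m, n}, out.device_ref(), in.device_ref(),
+  //                      gamma.device_ref(), beta.device_ref(), stream);
+  //
+  // The {1, n} shapes for gamma/beta are an assumption for illustration; only the raw device
+  // pointers are read, so any contiguous n-element buffer works.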
+ } +} + +} //namespace cutlass diff --git a/tools/util/include/cutlass/util/device_nhwc_pooling.h b/tools/util/include/cutlass/util/device_nhwc_pooling.h new file mode 100644 index 00000000..3a61442f --- /dev/null +++ b/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -0,0 +1,576 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. 
+ * \tparam T: data type + */ +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::MatrixCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream); + +/** get the output size of pooling + */ +inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) +{ + return (H_W + 2 * padding - kernel_size) / stride + 1; +} + +/** + * input is [N, H, W, C] + * assume stride == kernel_size + * output_h = (H + 2*padding_H - kernel_H)/stride_H + * output_w = (W + 2*padding_W - kernel_W)/stride_W + * output is [N, output_h, output_w, C] + * grid(N, output_h, output_w) + * block(min(C, 256)) : + * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) +*/ +template +__global__ void pooling_nhwc_element1_kernel(T* output, + const T* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float pooling; + if (IS_AVG_POOLING){ + pooling = 0.0f; + } + else{ + pooling = -FLT_MAX; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const float tmp = static_cast(input[idx + c_idx]); + if (IS_AVG_POOLING){ + pooling = pooling + tmp; + } + else{ + pooling = pooling > tmp ? pooling : tmp; + } + } + } + + T output_val; + if (IS_AVG_POOLING){ + output_val = T(pooling/kernel_size2); + } + else{ + output_val = T(pooling); + } + output[c_idx] = output_val; + } +} + +template +__global__ void pooling_nhwc_element2_kernel(T2* output, + const T2* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? 
W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float2 pooling; + if (IS_AVG_POOLING) { + pooling = {0.0f, 0.0f}; + } + else { + pooling = {-FLT_MAX, -FLT_MAX}; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const T2 tmp = input[idx + c_idx]; + const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; + if (IS_AVG_POOLING) { + pooling.x += tmp_flt2.x; + pooling.y += tmp_flt2.y; + } + else { + pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; + pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; + } + } + } + + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling.x/kernel_size2); + output_val.y = T(pooling.y/kernel_size2); + } + else { + output_val.x = T(pooling.x); + output_val.y = T(pooling.y); + } + output[c_idx] = output_val; + } +} + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C, N) + * block(block_size) -- each block deals with H*W/block_size elements; +*/ +template +__global__ void pooling_nxhTo1x1_element1_kernel( + T* output, const T* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[1]; + if (IS_AVG_POOLING) { + pooling[0] = 0.0f; + } + else { + pooling[0] = -FLT_MAX; + } + const size_t input_offset = n_idx * HW * C + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + float val = static_cast(input[index * C]); + if (IS_AVG_POOLING) { + pooling[0] += val; + } + else { + pooling[0] = pooling[0] > val ? pooling[0] : val; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T output_val; + if (IS_AVG_POOLING) { + output_val = T(pooling[0] / HW); + } + else { + output_val = T(pooling[0]); + } + output[0] = output_val; + } +} + + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C/2, N) + * block(block_size) -- each thread deals with H*W/block_size * 2 elements; +*/ +template +__global__ void pooling_nxhTo1x1_element2_kernel( + T2* output, const T2* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[2]; + if (IS_AVG_POOLING) { + pooling[0] = pooling[1] = 0.0f; + } + else { + pooling[0] = pooling[1] = -FLT_MAX; + } + const int C_2 = C / 2; + const size_t input_offset = n_idx * HW * C_2 + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C_2 + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + T2 val = input[index * C_2]; + float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; + if (IS_AVG_POOLING) { + pooling[0] += val_flt2.x; + pooling[1] += val_flt2.y; + } + else { + pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; + pooling[1] = pooling[1] > val_flt2.y ? 
pooling[1] : val_flt2.y; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling[0] / HW); + output_val.y = T(pooling[1] / HW); + } + else { + output_val.x = T(pooling[0]); + output_val.y = T(pooling[1]); + } + output[0] = output_val; + } +} + +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::Tensor4DCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream) { + + assert(input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.c()); + + assert(filter_tensor_size.h() == stride.row() && + filter_tensor_size.w() == stride.column()); + + const int N = input_tensor_size.n(); + const int H = input_tensor_size.h(); + const int W = input_tensor_size.w(); + const int C = input_tensor_size.c(); + const int padding_H = padding.h(); + const int padding_W = padding.w(); + const int kernel_H = filter_tensor_size.h(); + const int kernel_W = filter_tensor_size.w(); + const int stride_H = stride.row(); + const int stride_W = stride.column(); + + const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); + const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); + + assert(output_tensor_size.h() == output_H && + output_tensor_size.w() == output_W); + + if (C % 2 != 0) { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } // if (poolingType == 0) + else { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C < block.x) { + block.x = C; + } + if (poolingType == 0) { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (poolingType == 0) + else { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } // if (C % 2 != 0)) + else { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C/2, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + 
(float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C/2 < block.x) { + block.x = C/2; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } + } +} + +} //namespace cutlass diff --git a/tools/util/include/cutlass/util/device_utils.h b/tools/util/include/cutlass/util/device_utils.h new file mode 100644 index 00000000..a54b9894 --- /dev/null +++ b/tools/util/include/cutlass/util/device_utils.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief utils code for device cutlass code +*/ + +#pragma once + +#include +#include +#define FINAL_MASK 0xffffffff + +struct half4 { + half x, y, z, w; +}; + +template +__inline__ __device__ T warpReduceSum(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) +{ + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMax(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? 
shared[lane][i] : (T)(-FLT_MAX); + } + warpReduceMax(val); + + return (T)0.0f; +} + diff --git a/tools/util/include/cutlass/util/reference/device/convolution.h b/tools/util/include/cutlass/util/reference/device/convolution.h index 8c00b779..71666d04 100644 --- a/tools/util/include/cutlass/util/reference/device/convolution.h +++ b/tools/util/include/cutlass/util/reference/device/convolution.h @@ -123,11 +123,17 @@ __global__ void Conv2dFprop( } } + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + // Compute convolution for (int R = 0; R < problem_size.R; ++R) { for (int S = 0; S < problem_size.S; ++S) { for (int C = 0; C < problem_size.C; ++C) { + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + // Load from activations tensor int filter_r = R; int filter_s = S; @@ -154,9 +160,10 @@ __global__ void Conv2dFprop( CUTLASS_PRAGMA_UNROLL for (int n = 0; n < kThreadN; ++n) { int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; - if (thread_k < problem_size.K) { - element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C})); + if (thread_k < problem_size.K && k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); } else { element_B[n] = ElementAccumulator(); diff --git a/tools/util/include/cutlass/util/reference/host/convolution.h b/tools/util/include/cutlass/util/reference/host/convolution.h index 0e395527..9d670c51 100644 --- a/tools/util/include/cutlass/util/reference/host/convolution.h +++ b/tools/util/include/cutlass/util/reference/host/convolution.h @@ -86,11 +86,14 @@ void Conv2dFprop( for (int q = 0; q < problem_size.Q; ++q) { for (int k = 0; k < problem_size.K; ++k) { + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + ElementAccumulator acc = ElementAccumulator(); for (int r = 0; r < problem_size.R; ++r) { for (int s = 0; s < problem_size.S; ++s) { - for (int c = 0; c < problem_size.C; ++c) { + for (int c = 0; c < channels_per_group; ++c) { int filter_r = r; int filter_s = s; @@ -105,7 +108,7 @@ void Conv2dFprop( if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { - ElementA a = tensor_x.at({n, h, w, c}); + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); ElementB b = tensor_w.at({k, r, s, c}); acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); @@ -137,21 +140,21 @@ template , typename InnerProductOp = multiply_add > -void Depsep_Fprop( - cutlass::TensorView tensor_A, +void Depsep_Fprop(cutlass::TensorView tensor_A, cutlass::TensorView tensor_B, cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, ElementCompute alpha, ElementCompute beta, - cutlass::Tensor4DCoord padding, - cutlass::Coord<2> conv_stride, - cutlass::Coord<2> dilation, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { - + ConvertOp convert_op; InnerProductOp inner_product_op; @@ -163,15 +166,13 @@ void Depsep_Fprop( ElementAccumulator acc = ElementAccumulator(); for (int r = 0; r < tensor_B.extent().h(); ++r) { for (int s = 0; s < tensor_B.extent().w(); ++s) { - if ((p * conv_stride[0] - padding[0] + r * dilation[0]) < tensor_A.extent().h() && - (p * conv_stride[0] - padding[0] + r * dilation[0]) >= 0 && - (q * 
conv_stride[1] - padding[2] + s * dilation[1]) < tensor_A.extent().w() &&
-              (q * conv_stride[1] - padding[2] + s * dilation[1]) >= 0) {
-            ElementA a = tensor_A.at(
-                cutlass::make_Coord(n,
-                                    p * conv_stride[0] - padding[0] + r * dilation[0],
-                                    q * conv_stride[1] - padding[2] + s * dilation[1],
-                                    g));
+
+          // input activation H and W
+          int h = p * conv_stride[0] - padding[0] + r * dilation[0];
+          int w = q * conv_stride[1] - padding[2] + s * dilation[1];
+
+          if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) {
+            ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g));
 
             ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation)
                              ? tensor_B.at(cutlass::make_Coord(g, r, s, 0))
@@ -185,7 +186,7 @@ void Depsep_Fprop(
 
         // Apply Epilogue, compute ElementCompute, convert and store ElementC
         ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g));
-        tensor_C.at(cutlass::make_Coord(n, p, q, g)) =
+        tensor_D.at(cutlass::make_Coord(n, p, q, g)) =
             convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref));
       }
     }
diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h
index d2e29579..b76100b7 100644
--- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h
+++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h
@@ -694,10 +694,6 @@ struct TensorFillSymmetricRandomUniformFunc {
   }
 };
 
-
-//
-// We expect to release this with CUTLASS 2.4. -akerr
-
 /// Computes a random Uniform distribution and pads diagonal with zeros
 template <
   typename Element,               ///< Element type