From 7d49e6c7e2f8896c47f586706e67e1fb215529dc Mon Sep 17 00:00:00 2001 From: Vijay Thakkar Date: Thu, 11 Apr 2024 21:33:40 -0400 Subject: [PATCH] Updates for CUTLASS 3.5.0 (#1468) --- CHANGELOG.md | 7 +- CMakeLists.txt | 15 +- README.md | 11 +- .../ampere_tensorop_conv2dfprop.cu | 7 +- ...pere_3xtf32_fast_accurate_tensorop_gemm.cu | 60 +- .../29_3xtf32_complex_gemm.cu | 52 +- .../ampere_3xtf32_tensorop_symm.cu | 50 +- include/cute/algorithm/cooperative_copy.hpp | 146 +-- include/cute/algorithm/cooperative_gemm.hpp | 487 ++++++--- include/cute/arch/copy_sm50.hpp | 72 ++ include/cute/arch/util.hpp | 161 +-- include/cute/atom/copy_atom.hpp | 1 + include/cute/atom/copy_traits.hpp | 58 +- include/cute/atom/copy_traits_sm50.hpp | 58 ++ include/cute/atom/copy_traits_sm90_im2col.hpp | 267 +++-- include/cute/atom/copy_traits_sm90_tma.hpp | 71 +- include/cute/atom/mma_atom.hpp | 1 + include/cute/atom/mma_traits.hpp | 30 +- include/cute/atom/mma_traits_sm90_gmma.hpp | 2 +- include/cute/config.hpp | 2 +- include/cute/int_tuple.hpp | 19 +- include/cute/layout.hpp | 204 ++-- include/cute/layout_composed.hpp | 8 +- include/cute/numeric/arithmetic_tuple.hpp | 191 ++-- include/cute/numeric/complex.hpp | 8 +- include/cute/pointer.hpp | 7 +- include/cute/pointer_flagged.hpp | 10 +- include/cute/tensor.hpp | 3 +- include/cute/util/type_traits.hpp | 14 + include/cutlass/array.h | 125 ++- include/cutlass/array_planar_complex.h | 20 +- include/cutlass/array_subbyte.h | 122 +-- include/cutlass/bfloat16.h | 10 - include/cutlass/cluster_launch.hpp | 16 +- include/cutlass/complex.h | 9 - include/cutlass/conv/conv2d_problem_size.h | 23 +- include/cutlass/conv/conv3d_problem_size.h | 24 +- include/cutlass/conv/convolution.h | 13 +- .../conv/device/conv_universal_adapter.hpp | 14 +- .../conv/device/implicit_gemm_convolution.h | 12 +- include/cutlass/conv/kernel/default_conv2d.h | 22 + .../conv/kernel/default_conv2d_fprop.h | 34 +- .../conv/kernel/default_conv2d_fprop_fusion.h | 2 +- .../kernel/default_conv2d_fprop_with_absmax.h | 2 +- .../default_conv2d_fprop_with_broadcast.h | 9 +- .../default_conv2d_fprop_with_reduction.h | 2 +- .../conv/kernel/default_conv2d_group_fprop.h | 2 +- .../conv/kernel/default_conv3d_fprop.h | 82 +- .../conv/kernel/default_conv3d_fprop_fusion.h | 2 +- .../default_conv3d_fprop_with_broadcast.h | 9 +- .../cutlass/conv/kernel/default_deconv2d.h | 983 ++++++++++++++++++ .../kernel/default_deconv2d_with_broadcast.h | 305 ++++++ .../cutlass/conv/kernel/default_deconv3d.h | 525 ++++++++++ .../kernel/default_deconv3d_with_broadcast.h | 303 ++++++ .../conv/kernel/default_depthwise_fprop.h | 4 +- .../cutlass/conv/kernel/direct_convolution.h | 2 +- .../conv/kernel/implicit_gemm_convolution.h | 7 +- ...cit_gemm_convolution_with_fused_epilogue.h | 2 +- ...sm90_implicit_gemm_tma_warpspecialized.hpp | 1 - ...rop_filter_tile_access_iterator_analytic.h | 31 +- ...op_filter_tile_access_iterator_optimized.h | 17 +- ...rop_filter_tile_access_iterator_analytic.h | 15 +- ...op_filter_tile_access_iterator_optimized.h | 14 +- include/cutlass/coord.h | 10 - include/cutlass/core_io.h | 9 - include/cutlass/cutlass.h | 10 - .../collective/collective_builder.hpp | 2 +- .../sm90_epilogue_tma_warpspecialized.hpp | 3 +- ...90_visitor_compute_tma_warpspecialized.hpp | 29 +- include/cutlass/epilogue/thread/activation.h | 2 +- .../linear_combination_bias_elementwise.h | 66 +- .../thread/linear_combination_clamp.h | 30 +- .../linear_combination_planar_complex.h | 39 +- .../threadblock/default_epilogue_simt.h | 31 +- 
.../threadblock/default_epilogue_tensor_op.h | 22 +- .../default_epilogue_with_broadcast.h | 54 + .../threadblock/output_iterator_parameter.h | 52 +- .../threadblock/predicated_tile_iterator.h | 60 +- .../predicated_tile_iterator_conv.h | 562 ++++++++++ .../predicated_tile_iterator_params.h | 12 - ....h => shared_load_iterator_pitch_linear.h} | 4 +- include/cutlass/fast_math.h | 9 - include/cutlass/float8.h | 9 - include/cutlass/functional.h | 42 +- ...ma_gmma_rs_warpspecialized_mixed_input.hpp | 4 +- .../gemm/device/gemm_universal_adapter.h | 12 +- include/cutlass/gemm/gemm_enumerated_types.h | 10 - include/cutlass/gemm/kernel/gemm_universal.h | 46 +- .../sm90_gemm_warpspecialized_cooperative.hpp | 1 - .../sm90_gemm_warpspecialized_pingpong.hpp | 1 - .../gemm/kernel/tile_scheduler_params.h | 10 - include/cutlass/gemm/thread/mma_sm60.h | 49 +- .../threadblock/threadblock_swizzle_streamk.h | 10 - include/cutlass/half.h | 10 - include/cutlass/integer_subbyte.h | 10 - include/cutlass/kernel_hardware_info.h | 10 - include/cutlass/layout/matrix.h | 10 - include/cutlass/layout/pitch_linear.h | 10 - include/cutlass/numeric_conversion.h | 309 ++++-- include/cutlass/numeric_size.h | 10 - include/cutlass/numeric_types.h | 9 - include/cutlass/pipeline/sm90_pipeline.hpp | 1 - include/cutlass/platform/platform.h | 14 +- .../predicated_tile_access_iterator_params.h | 10 - include/cutlass/uint128.h | 9 - include/cutlass/workspace.h | 10 - .../building_with_clang_as_host_compiler.md | 14 +- media/docs/cute/02_layout_algebra.md | 20 +- media/docs/fundamental_types.md | 7 +- media/docs/profiler.md | 1 - python/cutlass/backend/gemm_operation.py | 2 +- python/cutlass/emit/common.py | 12 +- python/cutlass/emit/pytorch.py | 20 +- python/cutlass/library_defaults.py | 4 +- python/cutlass/op/gemm.py | 2 +- python/cutlass_library/gemm_operation.py | 11 +- python/cutlass_library/generator.py | 3 +- python/cutlass_library/library.py | 6 + test/python/cutlass/emit/pytorch.py | 5 +- test/unit/conv/cache_testbed_output.h | 13 +- test/unit/conv/device/CMakeLists.txt | 7 + ...nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu | 6 +- .../conv2d_fprop_with_broadcast_simt_sm80.cu | 7 +- test/unit/conv/device/conv2d_testbed.h | 16 +- .../device/conv2d_with_broadcast_testbed.h | 21 +- ...wc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu | 3 +- .../conv3d_fprop_with_broadcast_simt_sm80.cu | 7 +- test/unit/conv/device/conv3d_testbed.h | 22 +- .../device/conv3d_with_broadcast_testbed.h | 21 +- ...m_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu | 139 +++ .../deconv2d_with_broadcast_simt_sm80.cu | 173 +++ ...32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu | 141 +++ .../deconv3d_with_broadcast_simt_sm80.cu | 172 +++ test/unit/conv/device_3x/CMakeLists.txt | 1 + test/unit/conv/device_3x/testbed_conv.hpp | 24 +- test/unit/core/CMakeLists.txt | 27 - test/unit/core/cpp11.cu | 87 -- test/unit/core/numeric_conversion.cu | 16 +- test/unit/cute/CMakeLists.txt | 3 + test/unit/cute/ampere/CMakeLists.txt | 1 + test/unit/cute/ampere/cooperative_gemm.cu | 300 ++++++ test/unit/cute/cooperative_gemm_common.hpp | 414 ++++++++ test/unit/cute/core/complement.cpp | 35 +- test/unit/cute/core/composition.cpp | 39 +- test/unit/cute/core/logical_divide.cpp | 51 +- test/unit/cute/hopper/CMakeLists.txt | 5 + test/unit/cute/hopper/tma_load.cu | 2 - test/unit/cute/hopper/tma_mcast_load.cu | 76 ++ .../cute/hopper/tma_mcast_load_testbed.hpp | 242 +++++ test/unit/cute/turing/CMakeLists.txt | 32 + test/unit/cute/turing/cooperative_gemm.cu | 58 ++ test/unit/cute/volta/CMakeLists.txt | 1 + 
test/unit/cute/volta/cooperative_copy.cu | 48 + test/unit/cute/volta/cooperative_gemm.cu | 421 ++++++++ .../linear_combination_planar_complex.cu | 2 +- test/unit/epilogue/threadblock/testbed.h | 13 +- test/unit/gemm/device/gemm_testbed_3x.hpp | 1 - .../sm90_gemm_f32_f32_f32_tensor_op_f32.cu | 44 + .../sm90_gemm_f8_f8_f8_tensor_op_fp32.cu | 11 +- test/unit/nvrtc/thread/testbed.h | 2 +- .../include/cutlass/library/operation_table.h | 44 +- tools/library/src/gemm_operation_3x.hpp | 3 +- tools/library/src/library_internal.h | 1 + tools/library/src/util.cu | 2 + .../profiler/gemm_operation_profiler.h | 4 +- tools/profiler/src/gemm_operation_profiler.cu | 19 +- .../util/include/cutlass/util/print_error.hpp | 77 +- .../util/reference/device/tensor_fill.h | 2 +- .../cutlass/util/reference/host/conv.hpp | 17 +- .../cutlass/util/reference/host/convolution.h | 23 +- .../cutlass/util/reference/host/gett.hpp | 4 +- 171 files changed, 7526 insertions(+), 1888 deletions(-) create mode 100644 include/cute/arch/copy_sm50.hpp create mode 100644 include/cute/atom/copy_traits_sm50.hpp create mode 100644 include/cutlass/conv/kernel/default_deconv2d.h create mode 100644 include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h create mode 100644 include/cutlass/conv/kernel/default_deconv3d.h create mode 100644 include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h create mode 100644 include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h rename include/cutlass/epilogue/threadblock/{shared_load_iterator_pitch_liner.h => shared_load_iterator_pitch_linear.h} (98%) create mode 100644 test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu create mode 100644 test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu create mode 100644 test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu create mode 100644 test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu delete mode 100644 test/unit/core/cpp11.cu create mode 100644 test/unit/cute/ampere/cooperative_gemm.cu create mode 100644 test/unit/cute/cooperative_gemm_common.hpp create mode 100644 test/unit/cute/hopper/tma_mcast_load.cu create mode 100644 test/unit/cute/hopper/tma_mcast_load_testbed.hpp create mode 100644 test/unit/cute/turing/CMakeLists.txt create mode 100644 test/unit/cute/turing/cooperative_gemm.cu create mode 100644 test/unit/cute/volta/cooperative_gemm.cu diff --git a/CHANGELOG.md b/CHANGELOG.md index 480b0dfc..245527e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # NVIDIA CUTLASS Changelog -## 3.5 (2024-03-18) +## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). @@ -12,8 +12,13 @@ - [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. +- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. 
+ + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934). + + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564). - Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). - Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). +- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. +- Fixes to greatly reduce build warnings. - Updates and bugfixes from the community (thanks!) ## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14) diff --git a/CMakeLists.txt b/CMakeLists.txt index dd06a605..4933b6c9 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ else() endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") +set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++17 if set") # To reduce duplicate version locations, parse the version out of the # main versions.h file and reuse it here. @@ -332,6 +332,18 @@ endif() +# Warnings-as-error exceptions and warning suppressions for Clang builds +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=implicit-int-conversion ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=implicit-int-conversion" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pass-failed ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=pass-failed" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=inconsistent-missing-override ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=inconsistent-missing-override" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-sign-conversion" ) +endif() + if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. 
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") @@ -357,6 +369,7 @@ if (CUTLASS_ENABLE_OPENMP_TESTS) message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.") endif() endif() + if(UNIX) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing) diff --git a/README.md b/README.md index 98ddbb01..865ffb76 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # CUTLASS 3.5 -_CUTLASS 3.5 - March 2024_ +_CUTLASS 3.5 - April 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -45,18 +45,21 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im CUTLASS 3.5 is an update to CUTLASS adding: -- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) +- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp). + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). - + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms. + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x +- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x. + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. +- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. - Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). 
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). +- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. +- Fixes to greatly reduce build warnings. - Updates and bugfixes from the community (thanks!) Minimum requirements: diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 7a800c0b..c0395f58 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -265,6 +265,10 @@ constexpr int NumStages = 3; // Which iterator algorithm to use: Analytic or Optimized static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; +// Is the output packed or strided +// Use kStride if using strided output +static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kUnity; + // The epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix. @@ -289,7 +293,8 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< SwizzleThreadBlock, NumStages, cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm + IteratorAlgorithm, + OutputStride >::Kernel; // Type of the actual kernel diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu index 2c3c9011..1ecd38ee 100644 --- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu +++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu @@ -36,7 +36,7 @@ implicitly to tf32 inside the GEMM kernel which means no change is needed to acc fp32 data by using NVIDIA Ampere architecture. We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. -This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 results (SGEMM using SIMT) and against FP64 results (DGEMM) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. Now, we have several different flavors of sgemm now in the profiler for Ampere. 
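A short numerical sketch of the 3xTF32 big+small split described in the comment above. This is illustrative only and not part of the patch: the to_tf32 helper below is a hypothetical stand-in that emulates TF32 by truncating the FP32 mantissa to its top 10 bits, whereas the real kernels use the tensor-core TF32 path selected by OpMultiplyAddFastF32.

// Host-only sketch: compare FP32, 1xTF32, and 3xTF32 products.
#include <cstdint>
#include <cstring>
#include <cstdio>

// Emulate TF32 by keeping sign, exponent, and the top 10 mantissa bits.
static float to_tf32(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float a = 1.2345678f;
  float b = 7.6543210f;

  // a = a_big + a_small, b = b_big + b_small
  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);

  // 3xTF32 keeps the three largest partial products; a_small x b_small is dropped.
  float p_3x = a_big * b_big + a_big * b_small + a_small * b_big;
  // 1xTF32 is a single reduced-precision product.
  float p_1x = a_big * b_big;

  std::printf("FP32   : %.9f\n", a * b);
  std::printf("3xTF32 : %.9f\n", p_3x);
  std::printf("1xTF32 : %.9f\n", p_1x);
  return 0;
}

The 3xTF32 sum recovers nearly all of the FP32 mantissa, which is why the examples only need to switch the default OpMultiplyAdd to OpMultiplyAddFastF32 in the kernel definitions that follow.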
Here are the difference @@ -97,14 +97,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -147,7 +147,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -190,7 +190,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -227,9 +227,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -272,10 +272,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -296,7 +296,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -318,7 +318,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -356,7 +356,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -397,7 +397,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -411,7 +411,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -426,7 +426,7 @@ bool run(Options &options) { cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); 
cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -464,7 +464,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op_3xTF32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -568,7 +568,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -627,7 +627,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); //////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms //////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -664,7 +664,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -673,11 +673,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -690,7 +690,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -716,17 +716,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." << std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu index c4e6e958..18375f6d 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu @@ -34,7 +34,7 @@ difference is that this example uses 3xtf32 on complex gemm. To enable this feature, the only change needs to make is to change OpMultiplyAddComplex - to OpMultiplyAddComplexFastF32. + to OpMultiplyAddComplexFastF32. 
*/ #include @@ -74,14 +74,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -124,7 +124,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -153,7 +153,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -190,9 +190,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -239,7 +239,7 @@ constexpr int NumStages = 3; constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone; constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -260,7 +260,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplexFastF32>; @@ -281,7 +281,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplex>; @@ -296,7 +296,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -337,7 +337,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -351,7 +351,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -366,7 +366,7 @@ bool run(Options &options) { 
cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -404,7 +404,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -508,7 +508,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -569,7 +569,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); //////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms //////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -606,7 +606,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -615,11 +615,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -632,7 +632,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -658,17 +658,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." << std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu index 098ca8a2..4863ed93 100644 --- a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu +++ b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu @@ -36,7 +36,7 @@ implicitly to tf32 inside the SYMM kernel which means no change is needed to acc F32 data by using NVIDIA Ampere architecture. 
We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. -This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. Now, we have two different flavors of SSYMM in the profiler for Ampere: @@ -95,7 +95,7 @@ struct Options { float beta; std::string rand_mode; int seed; - + Options(): help(false), problem_size({4096, 4096, 4096}), @@ -137,7 +137,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -207,10 +207,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64) // @@ -233,7 +233,7 @@ using Symm_3xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -257,7 +257,7 @@ using Symm_1xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -298,7 +298,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -339,7 +339,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -353,7 +353,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Symm output (D) for SYMM_3xTF32 cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with 
dimensions M x N // Symm output (D) for SYMM_1xTF32 @@ -375,7 +375,7 @@ bool run(Options &options) { #if CUTLASS_ENABLE_CUBLAS cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view()); #endif - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -430,7 +430,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_3xTF32 symm_op_3xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -477,7 +477,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_1xTF32 symm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -524,7 +524,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_F64 symm_op_f64; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_f64 = symm_op_f64.can_implement(arguments_f64); CUTLASS_CHECK(status_f64); @@ -568,7 +568,7 @@ bool run(Options &options) { static_cast(&beta), static_cast(tensor_d_cublasF32.device_data()), int(tensor_d_cublasF32.layout().stride(0)) - ); + ); cudaDeviceSynchronize(); @@ -576,7 +576,7 @@ bool run(Options &options) { #endif //////////////////////////////////////////////////////////////////////////////// - /// 7. Compute l2 norms + /// 7. Compute l2 norms //////////////////////////////////////////////////////////////////////////////// #if CUTLASS_ENABLE_CUBLAS @@ -605,20 +605,20 @@ bool run(Options &options) { double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view()); #endif - + // l2 norm 3xTF32 vs 1xTF32 double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view()); /////////////////////////////////////////////////////////////////////////////// - // Print kernel info and L2 norms + // Print kernel info and L2 norms std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") " << "Alpha: " << alpha << "," << " Beta: " << beta << std::endl; std::cout << std::fixed; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific #if CUTLASS_ENABLE_CUBLAS << " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl #endif @@ -633,11 +633,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. 
if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -650,7 +650,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index c0337aba..78730710 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -71,85 +71,103 @@ cooperative_copy(uint32_t const& tid, // Precondition on tid in DEBUG assert(tid < NumThreads); - // Precondition on pointer alignment in DEBUG - assert(is_byte_aligned(raw_pointer_cast(src.data()))); - assert(is_byte_aligned(raw_pointer_cast(dst.data()))); - // - // Determine val+thr vectorization based on src/dst size and number of threads - // NOTE: This heuristic promotes parallelization over vectorization - // - constexpr int elem_bits = sizeof_bits_v; + // Fallback - slow path, naive copy, vectorization disabled + if constexpr(size(SrcLayout{}) % NumThreads != 0) { + int index = static_cast(tid); + CUTE_UNROLL + for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) { + if(index < size(SrcLayout{})) { + dst[index] = src[index]; + } + index += NumThreads; + } + } else { + // Fast path with vectorization - // The number of elements that can be vectorized in values - constexpr int common_elem = decltype(max_common_vector(src, dst))::value; - constexpr int common_bits = common_elem * elem_bits; - constexpr int total_elem = decltype(size(src))::value; - constexpr int total_bits = total_elem * elem_bits; - static_assert(total_bits % NumThreads == 0); - constexpr int total_bits_per_thr = total_bits / NumThreads; - // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits - constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + constexpr int elem_bits = sizeof_bits_v; - // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits - constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); - // Convert back to number of elements, safe_div - static_assert((vec_bits % elem_bits) == 0); - constexpr int vec_elem = vec_bits / elem_bits; + // + // Determine val+thr vectorization based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // - // Use only part of threads if there's not enough work for all threads - constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) - ? 
NumThreads - : (total_elem / vec_elem); + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src, dst))::value; + constexpr int common_bits = common_elem * elem_bits; + constexpr int total_elem = decltype(size(src))::value; + constexpr int total_bits = total_elem * elem_bits; + static_assert(total_bits % NumThreads == 0); + constexpr int total_bits_per_thr = total_bits / NumThreads; + // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits + constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); - // The common layout of the two tensors that can be vectorized over threads - // vidx -> coord - auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), - get_nonswizzle_portion(dst.layout())); + // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits + constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); + // Convert back to number of elements, safe_div + static_assert((vec_bits % elem_bits) == 0); + constexpr int vec_elem = vec_bits / elem_bits; - // Scale up the common_layout to cover the entire tensors - // vidx -> coord - auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); + // Use only part of threads if there's not enough work for all threads + constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) + ? NumThreads + : (total_elem / vec_elem); + static_assert(vec_thrs <= NumThreads); - // Create the Tiler - // ((vid,tid),iter) - auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); + // The common layout of the two tensors that can be vectorized over threads + // vidx -> coord + auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), + get_nonswizzle_portion(dst.layout())); - // Apply and slice - Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); - Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); + // Scale up the common_layout to cover the entire tensors + // vidx -> coord + auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); - // Should account for vec_bits < 8 and/or vec_elem <= 1 - // And also account for subbyte types, which could cause race conditions - // Want to ENFORCE sufficient vectorization in those cases - static_assert((vec_bits >= 8), "No support for subbyte copying"); - using VecType = uint_bit_t; + // Create the Tiler + // ((vid,tid),iter) + auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); + + // Apply and slice + Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); + Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert((vec_bits >= 8), "No support for subbyte copying"); + using VecType = uint_bit_t; #if 0 - if (thread0()) { - print(" "); print("NumThreads: "); print(NumThreads); print("\n"); - print(" "); print("src: "); print(src); print("\n"); - print(" "); print("dst: "); print(dst); print("\n"); - print(" "); print("common_layout: "); print(common_layout); print("\n"); - print(" "); print("full_perm: "); print(full_perm); print("\n"); - print(" "); print("Used vector: "); print(vec_elem); print("\n"); - print(" "); print("Used threads: "); print(vec_thrs); print("\n"); - print(" "); print("layout_vt: "); print(layout_vt); 
print("\n"); - print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); - print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); - print(" "); print("src_v: "); print(src_v); print("\n"); - print(" "); print("dst_v: "); print(dst_v); print("\n"); - print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); - print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); - } + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("full_perm: "); print(full_perm); print("\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("layout_vt: "); print(layout_vt); print("\n"); + print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); + print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } #ifdef __CUDA_ARCH__ - __syncthreads(); + __syncthreads(); #endif #endif - // If we're using all threads (static) or the tid is in in-range (dynamic) - if (vec_thrs >= NumThreads or tid < vec_thrs) { - return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + // If we're using all threads (static) or the tid is in in-range (dynamic) + if (vec_thrs >= NumThreads or tid < vec_thrs) { + return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + } } } diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index 32cec54b..b8388159 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -35,6 +35,7 @@ #include +#include #include #include @@ -44,40 +45,37 @@ namespace cute { // -// Collective Shared-Memory GEMMs +// Cooperative Shared-Memory GEMMs // +namespace detail { + +// Predicated Cooperative GEMM template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> CUTE_HOST_DEVICE void -cooperative_gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, - BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +cooperative_gemm_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C { - CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM - CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN - 
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK - using TypeA = typename TA::value_type; using TypeB = typename TB::value_type; using TypeC = typename TC::value_type; - static_assert(is_same_v>, TypeA>, - "ALoadTransformOp functor must accept and return value of type TA::value_type"); - static_assert(is_same_v>, TypeB>, - "BLoadTransformOp functor must accept and return value of type TB::value_type"); - // Original, static size of the problem auto M = size<0>(sC); auto N = size<1>(sC); @@ -88,39 +86,14 @@ cooperative_gemm(ThrMMA const& thr_mma, auto BLK_N = tile_size<1>(thr_mma); auto BLK_K = tile_size<2>(thr_mma); - // Compute the "residues" - auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] - auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] - auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] - - // Shift the origin so k_residue is zeroth tile - sA.data() = &sA(0,k_residue); - sB.data() = &sB(0,k_residue); - -#if 0 - if (thread0()) { - printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); - printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N)); - printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); - } -#endif - // // MMA Partitioning // - // Round the layout extents up to BLK_X - Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); - Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); - Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); - -#if 0 - if (thread0()) { - print("rounded_sA: "); print(rounded_sA); print("\n"); - print("rounded_sB: "); print(rounded_sB); print("\n"); - print("rounded_sC: "); print(rounded_sC); print("\n"); - } -#endif + // Round the layout extents up to BLK_X to satisfy MMA partitioning safety + Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K))); + Tensor rounded_sB = sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K))); + Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N))); // Partition the sA and sB tiles across the threads for the MMA Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) @@ -133,6 +106,13 @@ cooperative_gemm(ThrMMA const& thr_mma, #if 0 if (thread0()) { + print(" sA: "); print( sA); print("\n"); + print(" sB: "); print( sB); print("\n"); + print(" sC: "); print( sC); print("\n"); + print("r_sA: "); print(rounded_sA); print("\n"); + print("r_sB: "); print(rounded_sB); print("\n"); + print("r_sC: "); print(rounded_sC); print("\n"); + print(thr_mma); print("tCsA: "); print(tCsA); print("\n"); print("tCsB: "); print(tCsB); print("\n"); print("tCsC: "); print(tCsC); print("\n"); @@ -146,58 +126,232 @@ cooperative_gemm(ThrMMA const& thr_mma, // PREDICATION // - // Allocate the preds for only the MMA-mode of tCsA and tCsB - Tensor tCpA = make_tensor(size<0>(tCsA)); - Tensor tCpB = make_tensor(size<0>(tCsB)); - - // Create coordinate tensors on a single compute block for predication - Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) + // Create coordinate tensors for the problem + Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k) // Repeat partitioning with thr_mma - Tensor tCcA = 
thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) - Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k) - // Populate the m and n predicates + // Allocate the preds for MMA- and MMA_MN-modes + Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); + Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); + + // Populate the predicates on M and N CUTE_UNROLL for (int i = 0; i < size(tCpA); ++i) { - tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); + tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA)); } CUTE_UNROLL for (int i = 0; i < size(tCpB); ++i) { - tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); + tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); } #if 0 - printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", - threadIdx.x, - int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), - int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); + if (thread0()) { + print(" cA: "); print( cA); print("\n"); + print(" cB: "); print( cB); print("\n"); + print("tCcA: "); print(tCcA); print("\n"); + print("tCcB: "); print(tCcB); print("\n"); + print_tensor(tCpA); + print_tensor(tCpB); + } #endif // - // PREFETCH k_block = 0 (with k-predication) + // PREFETCH k_block = 0 + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial // - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I - if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k - CUTE_UNROLL - for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m - tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; - } - } - } + constexpr int K_BLOCK_MAX = size<2>(tCrA); CUTE_UNROLL - for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I - if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k - CUTE_UNROLL - for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n - tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; - } + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; } } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; + } + } + // + // MAINLOOP + // + + // Clear accumulators + clear(tCrC); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block + { + int k_next = k_block + 1; // Load k_next block + + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? 
sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + } + } + } + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } + + // + // Epilogue + // + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = (beta == Beta{}); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast(tCrC(i)) + : alpha * static_cast(tCrC(i)) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +// Slow fallback path +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm_predication(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op); +} + +// Unpredicated Cooperative GEMM +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm_no_predication(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + // + // MMA Partitioning + // + + Tensor tCsC = thr_mma.partition_C(sC); + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + + using CopyOpAType = SmemCopyOpA; + using CopyOpBType = SmemCopyOpB; + + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + 
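For reference, a minimal sketch of how the new thread_idx-based entry point that dispatches into this unpredicated path might be invoked, with the smem copy ops and the load/store transforms left at their defaults. The TiledMMA construction, tensor names (sA, sB, sC) and scalars (alpha, beta) are assumptions for illustration only, not taken from this patch; they presume half_t A/B operands, float accumulators, and smem tiles weakly compatible with the MMA tile shape.

using namespace cute;
// Hypothetical MMA tiling: one SM80 16x8x16 F16 atom replicated 2x2x1 across the warps.
auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
                                Layout<Shape<_2,_2,_1>>{});
// sA: (M,K), sB: (N,K), sC: (M,N) shared-memory tensors; alpha/beta: float scalars.
cooperative_gemm(threadIdx.x, tiled_mma, alpha, sA, sB, beta, sC);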
auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + +#if 0 + if (thread0()) { + print(" sA: "); print(sA); print("\n"); + print(" sB: "); print(sB); print("\n"); + print(" sC: "); print(sC); print("\n"); + print(thr_mma); print("\n"); + print("tCsC: "); print(tCsC); print("\n"); + print("tCrA: "); print(tCrA); print("\n"); + print("tCrB: "); print(tCrB); print("\n"); + print("tCrC: "); print(tCrC); print("\n"); + print(smem_thr_copy_A); print("\n"); + print("tCsA: "); print(tCsA); print("\n"); + print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n"); + print(smem_thr_copy_B); print("\n"); + print("tCsB: "); print(tCsB); print("\n"); + print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n"); + } +#endif + + // + // PREFETCH + // + + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); // // MAINLOOP // @@ -214,25 +368,15 @@ cooperative_gemm(ThrMMA const& thr_mma, if (k_block < K_BLOCK_MAX-1) { // Load the next k_block - int k_next = k_block + 1; - - CUTE_UNROLL - for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m - tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; - } - } - - CUTE_UNROLL - for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n - tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; - } - } + int k_next = k_block + 1; // statically unrolled + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next)); } + // Transform A and B, relying on the compiler to remove in case of identity ops + cute::transform(tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrB(_,_,k_block), sB_load_op); + // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } @@ -241,53 +385,124 @@ cooperative_gemm(ThrMMA const& thr_mma, // Epilogue // - Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) - - const bool isBetaZero = (beta == Beta{}); - - // Custom axpby_if for now - CUTE_UNROLL - for (int m = 0; m < size<1>(tCsC); ++m) - { - CUTE_UNROLL - for (int n = 0; n < size<2>(tCsC); ++n) - { - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsC); ++i) - { - if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && - (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) - { - tCsC(i,m,n) = isBetaZero ? 
alpha * static_cast(tCrC(i,m,n)) : alpha * static_cast(tCrC(i,m,n)) + beta * static_cast(tCsC(i,m,n)); - } - } + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + using CopyOpCType = SmemCopyOpC; + Tensor tCrD = thr_mma.make_fragment_C(tCsC); + if(!isBetaZero) { + copy(CopyOpCType{}, tCsC, tCrD); + // Transform C on/after load + cute::transform(tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, sC_store_op); + copy(CopyOpCType{}, tCrD, tCsC); +} + +} // end namespace detail + +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +{ + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + static_assert(is_convertible_v>, TypeA>, + "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type"); + static_assert(is_convertible_v>, TypeB>, + "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type"); + static_assert(is_convertible_v>, TypeC>, + "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); + static_assert(is_convertible_v>, TypeC>, + "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); + + static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA{}), + make_shape(size<0>(sA), size<0>(sB), size<1>(sA))); + if constexpr (compat) { + detail::cooperative_gemm_no_predication( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); + } else { + detail::cooperative_gemm_predication( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); } } template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> CUTE_HOST_DEVICE void -cooperative_gemm(ThrMMA const& thr_mma, +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, Alpha const& alpha, Tensor sA, Tensor sB, Beta const& beta, - Tensor sC) + Tensor sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* 
sA_load_op */, identity() /* sB_load_op */); + using CopyOpA = AutoVectorizingCopyWithAssumedAlignment>; + using CopyOpB = AutoVectorizingCopyWithAssumedAlignment>; + using CopyOpC = AutoVectorizingCopyWithAssumedAlignment>; + cooperative_gemm( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); } +// Legacy overload of cute::gemm for backwards-compatibility template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> @@ -299,28 +514,16 @@ gemm(ThrMMA const& thr_mma, Tensor sB, Beta const& beta, Tensor sC, - ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, - BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op); -} - -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> -CUTE_HOST_DEVICE -void -gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC) -{ - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); + // Goes directly to the slow path to avoid getting thread_idx from thr_mma + detail::cooperative_gemm_predication( + thr_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); } } // end namespace cute diff --git a/include/cute/arch/copy_sm50.hpp b/include/cute/arch/copy_sm50.hpp new file mode 100644 index 00000000..9cf0efcd --- /dev/null +++ b/include/cute/arch/copy_sm50.hpp @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 500 + #define CUTE_ARCH_WARP_SHUFFLE_ENABLED 1 +#endif + +namespace cute +{ + +struct SM50_Shuffle_U32_2x2Trans +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = src0; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 1); + + uint32_t x1 = src1; + uint32_t y1 = __shfl_xor_sync(0xffffffff, x1, 1); + + if (threadIdx.x % 2 == 0) { + dst1 = y0; + } + else { + dst0 = y1; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + + +} // end namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 06add577..92e34251 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -117,7 +117,7 @@ cast_smem_ptr_to_uint(void const* const ptr) uint32_t smem_ptr; asm( - "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" : "=r"(smem_ptr) : "l"(ptr)); return smem_ptr; @@ -132,11 +132,47 @@ cast_smem_ptr_to_uint(void const* const ptr) #endif } +namespace detail { + // -// Utility for pointer interfaces +// Wrapper for MMAOp::fma // -namespace detail { +template +struct CallFMA { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return MmaOp::fma(static_cast(args)...); + } +}; + +// +// Wrapper for CopyOp::copy +// + +template +struct CallCOPY { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return CopyOp::copy(static_cast(args)...); + } +}; + +// +// Utility for exploding pointers/arrays/tensors into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence) +{ + return fn(a[I]...); +} template -CUTE_HOST_DEVICE constexpr -void -explode_with_d_scaling(Fn fn, - PtrA&& a, int_sequence, - PtrB&& b, int_sequence, - PtrC&& c, int_sequence, - ParamType&& p0) -{ - return fn(a[Ia]..., b[Ib]..., c[Ic]..., p0); -} - template + class PtrE, int... 
Ie> CUTE_HOST_DEVICE constexpr void -explode_with_d_scaling(Fn fn, +explode(Fn fn, PtrD&& d, int_sequence, PtrA&& a, int_sequence, PtrB&& b, int_sequence, PtrC&& c, int_sequence, - ParamType&& p0) + PtrE&& e, int_sequence) { - return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., p0); + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrSFA&& sfa, int_sequence, + PtrSFB&& sfb, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., sfa[Isfa]..., sfb[Isfb]...); +} +// +// Utility for exploding tuples into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence) +{ + return fn(get(a)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence) +{ + return fn(get(a)..., get(b)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence, + TupleC&& c, int_sequence) +{ + return fn(get(a)..., get(b)..., get(c)...); } } // end namespace detail -template -CUTE_HOST_DEVICE constexpr -void -explode(Fn fn, PtrS&& s, PtrD&& d) -{ - return detail::explode(fn, - s, make_int_sequence{}, - d, make_int_sequence{}); -} - -template -CUTE_HOST_DEVICE constexpr -void -explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) -{ - return detail::explode(fn, - a, make_int_sequence{}, - b, make_int_sequence{}, - c, make_int_sequence{}); -} - -template -CUTE_HOST_DEVICE constexpr -void -explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) -{ - return detail::explode(fn, - d, make_int_sequence{}, - a, make_int_sequence{}, - b, make_int_sequence{}, - c, make_int_sequence{}); -} - } // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index d1cd3d4b..48a5fd16 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -756,6 +756,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and //////////////////////////////////////////////////////////////////////////////////////////////////// +#include #include #include #include diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index b6259b58..2aa3ba57 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -92,59 +92,6 @@ struct Copy_Traits> using RefLayout = SrcLayout; }; -namespace detail { - -// Utility for exploding pointers, arrays, or tensors into Operation::copy -template -CUTE_HOST_DEVICE constexpr -void -copy_explode_index(PtrSrc&& s, int_sequence, - PtrDst&& d, int_sequence) -{ - return Operation::copy(s[Is]..., d[Id]...); -} - -// Utility for exploding tuples into ::copy -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleArg&& t, int_sequence) -{ - return Operation::copy(get(static_cast(t))...); -} - -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleSrc&& s, int_sequence, - TupleDst&& d, int_sequence) -{ - return Operation::copy(get(static_cast(s))..., - get(static_cast(d))...); -} - -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleAux&& a, int_sequence, - TupleSrc&& s, int_sequence, - TupleDst&& d, int_sequence) -{ - return Operation::copy(get(static_cast(a))..., - get(static_cast(s))..., - get(static_cast(d))...); -} - -} // end namespace detail - // // Generic copy_unpack for common 
argument-based Copy_Traits // @@ -177,8 +124,9 @@ copy_unpack(Copy_Traits const&, CUTE_STATIC_ASSERT_V(size(rD) == Int{}, "Copy_Traits: dst failed to vectorize into registers. Layout is incompatible with this CopyOp."); - detail::copy_explode_index(rS, make_int_sequence{}, - rD, make_int_sequence{}); + detail::explode(detail::CallCOPY{}, + rS, make_int_sequence{}, + rD, make_int_sequence{}); } // diff --git a/include/cute/atom/copy_traits_sm50.hpp b/include/cute/atom/copy_traits_sm50.hpp new file mode 100644 index 00000000..8be0ef7b --- /dev/null +++ b/include/cute/atom/copy_traits_sm50.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1, _64>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index 34e71ed6..15d9979c 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -73,14 +73,16 @@ struct TMA_LOAD_IM2COL_Unpack CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); if constexpr (detail::is_prefetch) { - return detail::copy_explode(traits.opargs_, tuple_seq{}, - src_coord_cwhdn_offset_srt, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); } else { static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); void* dst_ptr = cute::raw_pointer_cast(dst.data()); - return detail::copy_explode(traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord_cwhdn_offset_srt, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); } } }; @@ -349,8 +351,9 @@ struct Copy_Traits void const* const src_ptr = cute::raw_pointer_cast(src.data()); auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); - return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); } }; @@ -537,6 +540,133 @@ make_im2col_tma_copy_desc( return cute::make_tuple(tma_desc, tma_tensor); } +template +CUTE_HOST_RTC +auto +make_tma_atom_im2col(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) // dilation +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); 
print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); + // IM2COL uses a maximum of 2 modes + constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + + // Split according to the portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), num_multicast)); + +#if 0 + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); +#endif + + auto range_c = size<0,0>(tma_layout_vt); + auto range_whdn = size<0,1>(tma_layout_vt); + + Tensor gtensor_cwhdn = make_tensor(gtensor.data(), + flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.layout())))); + + auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( + gtensor_cwhdn, + range_c, + range_whdn, + detail::get_swizzle_portion(slayout), + tma_layout_vt, + lower_corner_whd, + upper_corner_whd, + lower_padding_whd, + upper_padding_whd, + stride_whd, + lower_srt, + stride_srt); + + // + // Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof(T) * 8; + + using Traits = Copy_Traits, decltype(tma_tensor)>; + using Atom = Copy_Atom; + +#if 0 + print("num_bits : "); print(num_bits_per_tma); print("\n"); +#endif + + Traits tma_traits{tma_desc, tma_tensor}; + + // Return the Copy_Atom + return Atom{tma_traits}; +} + /// Make a TiledCopy for im2col TMA load. 
/// /// @param copy_op The copy implementation: either @@ -584,99 +714,12 @@ make_tma_copy_im2col(CopyOp const& copy_op, // TMA parameter checking // - CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), - "TMA requires CTA_Tile and SLayout top-level shape equivalence."); CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, "Number of active CTAs in TMA must divide domain size of slayout."); - // - // TMA slayout manipulation - // - - // Invert the smem to get the largest contiguous vector in the smem layout - auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); - // trunc_smem_idx -> trunc_smem_coord - - // Map from smem idx to a gmem mode - auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); - -#if 0 - print("g_layout : "); print(gtensor.layout()); print("\n"); - print("s_layout : "); print(slayout); print("\n"); - print("cta_t_map : "); print(cta_t_map); print("\n"); - print("cta_v_map : "); print(cta_v_map); print("\n"); - print("inv_smem : "); print(inv_smem_layout); print("\n"); - print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); -#endif - - // - // TMA gtensor manipulation - // - - // Generate a TupleBasis for the gtensor - auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); - - // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc - auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); - - // Truncate any incompatibilities -- no starting in the middle of gmodes - auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { - [[maybe_unused]] auto v = basis_value(e); - return not is_constant<1,decltype(v)>{}; - }); - static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); - // IM2COL uses a maximum of 2 modes - constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); - - // Keep only the static-1 basis modes into gmem - auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); - - // Split according to the portion each multicast CTA will be responsible for - auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); - -#if 0 - print("glayout_basis : "); print(glayout_basis); print("\n"); - print("tma_layout_full : "); print(tma_layout_full); print("\n"); - - print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); - print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); -#endif - - auto range_c = size<0,0>(tma_layout_vt); - auto range_whdn = size<0,1>(tma_layout_vt); - - Tensor gtensor_cwhdn = make_tensor(gtensor.data(), - flatten(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.layout()), - basis_get(stride<0,1>(tma_layout_vt), gtensor.layout())))); - - auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( - gtensor_cwhdn, - range_c, - range_whdn, - detail::get_swizzle_portion(slayout), - tma_layout_vt, - lower_corner_whd, - upper_corner_whd, - lower_padding_whd, - upper_padding_whd, - stride_whd, - lower_srt, - stride_srt); - - // - // Construct the Copy_Traits - // - - using T = typename GEngine::value_type; - constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8; - - using Traits = Copy_Traits, decltype(tma_tensor)>; - -#if 0 - print("num_bits : "); print(NumBitsPerTMA{}); print("\n"); -#endif - - Traits tma_traits{tma_desc, tma_tensor}; + Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, + 
lower_corner_whd, upper_corner_whd, lower_padding_whd, + upper_padding_whd, stride_whd, lower_srt, stride_srt); // // Construct the TiledCopy @@ -684,25 +727,31 @@ make_tma_copy_im2col(CopyOp const& copy_op, auto cta_tiler = product_each(shape(cta_v_map)); - // (CTA V, CTA T) -> smem_coord - auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt))); + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); // Scale that up to cover all of the smem_coords - // - // The smem vector might not cover all of the tile, - // so multiply it up to cover the entire tile. - // "T" here (the parallel index) is a CTA index. - auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); - // Flip it and change the domain of the T from logical thr to thr_idx - auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); #if 0 print("cta_tiler : "); print(cta_tiler); print("\n"); - print("layout_VT : "); print(layout_VT); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); print("layout_TV : "); print(layout_TV); print("\n"); #endif - using T = typename GEngine::value_type; - return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; + return TiledCopy{atom}; } /// Make a TiledCopy for im2col TMA with no offsets. 
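A small host-side sketch of the thread/value layout concatenation used above when building the TiledCopy. The sizes are hypothetical (2 multicast CTAs sharing one 128-element TMA over an identity smem layout, 64 elements per CTA) and are not taken from the patch; they only show how make_layout concatenates the thread map and the value map into a single (T,V) -> smem-index layout.

#include <cute/tensor.hpp>
using namespace cute;

int main() {
  auto layout_V  = make_layout(Int<128>{});           // CTA value idx -> smem idx (stride-1)
  auto layout_T  = make_layout(Int<2>{}, Int<64>{});  // CTA rank -> starting smem idx of its slice
  auto layout_TV = make_layout(layout_T, layout_V);   // (T,V) -> smem idx
  print(layout_TV); print("\n");                      // prints (_2,_128):(_64,_1)
  return 0;
}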
diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 16b2a648..d42c82c9 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -69,8 +69,9 @@ struct TMA_LOAD_Unpack { auto src_coord = src.data().coord_; if constexpr (detail::is_prefetch) { - return detail::copy_explode(traits.opargs_, tuple_seq{}, - src_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); } else { static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); void* dst_ptr = cute::raw_pointer_cast(dst.data()); @@ -81,9 +82,10 @@ struct TMA_LOAD_Unpack blockIdx.x, blockIdx.y, blockIdx.z, int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); #endif - return detail::copy_explode(traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); } } }; @@ -337,8 +339,9 @@ struct Copy_Traits blockIdx.x, blockIdx.y, blockIdx.z, int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); #endif - return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); } }; @@ -1278,7 +1281,7 @@ tma_partition(Copy_Atom const& copy_atom, // Factor out the single-instrucion portion Layout tma_layout_v = make_layout(Int::NumValSrc>{}); auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); - + // Append with _ until we cover all Rest... modes auto glayout_V = append>(layout_V, _); auto slayout_V = append>(layout_V, _); @@ -1288,39 +1291,45 @@ tma_partition(Copy_Atom const& copy_atom, #if 0 if (thread0()) { - print("gtensor : "); print(gtensor); print("\n"); - print("stensor : "); print(stensor); print("\n"); + print("cta_coord : "); print(cta_coord); print("\n"); + print("cta_layout : "); print(cta_layout); print("\n"); + print("gtensor : "); print(gtensor); print("\n"); + print("stensor : "); print(stensor); print("\n"); print("layout_V : "); print(layout_V); print("\n"); print("gtensor_v : "); print(gtensor_v); print("\n"); print("stensor_v : "); print(stensor_v); print("\n"); } #endif - // Restride the cta-into-tma-instr layout - Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout); - auto tma_layout_tv = make_tile(make_tile(make_layout(tma_layout_t, tma_layout_v), _)); + // Offset inside the TMA-mode for the multicast + auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); + auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); + auto scoord = append(multicast_coord, Int<0>{}); + auto gcoord = append(multicast_coord, Int<0>{}); - // Append with _ until we cover all Rest... modes - auto gtma_layout_tv = append>(tma_layout_tv, _); - auto stma_layout_tv = append>(tma_layout_tv, _); + Tensor gresult = domain_offset(gcoord, gtensor_v); + Tensor sresult = domain_offset(scoord, stensor_v); - // Transform TMA mode - Tensor gtensor_tv = gtensor_v.compose(gtma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) - Tensor stensor_tv = stensor_v.compose(stma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) 
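A quick numeric illustration of the multicast offset computed in the reworked tma_partition above; the figures (4 multicast CTAs, 256 values covered by one TMA instruction, CTA rank 3) are hypothetical and chosen only to make the arithmetic concrete.

// Assumed values: cta_layout(cta_coord) == 3, size(tma_layout_v) == 256, cosize(cta_layout) == 4
int cta_rank = 3;
int tma_vals = 256;
int num_ctas = 4;
int multicast_offset = cta_rank * (tma_vals / num_ctas);  // 3 * 64 == 192: first value owned by this CTA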
+ return cute::make_tuple(gresult, sresult); +} -#if 0 - if (thread0()) { - print("tma_layout_tv : "); print(tma_layout_tv); print("\n"); - print("gtensor_tv : "); print(gtensor_tv); print("\n"); - print("stensor_tv : "); print(stensor_tv); print("\n"); +// TMA Multicast Masks Calculation +template +CUTE_HOST_DEVICE constexpr +auto +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + auto cta_coord_slicer = replace(cta_coord_vmnk, _); + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); + // Get the instruction code + uint16_t mcast_mask = 0; + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); } -#endif - - auto c = make_coord(make_coord(make_coord(cta_coord, _), _)); - auto c_s = append>(c, _); - auto c_g = append>(c, _); - - return cute::make_tuple(group_modes<0,2>(gtensor_tv(c_g)), group_modes<0,2>(stensor_tv(c_s))); + // Shift by the instruction's elected block rank (dynamic) + mcast_mask <<= elected_cta; + return mcast_mask; } } // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 674e3519..9e5c93f2 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -715,6 +715,7 @@ print(MMA_Atom> const&) using Atom = MMA_Atom>; print("MMA_Atom\n"); print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" Shape_MNK: "); print(typename Atom::Shape_MNK{}); print("\n"); print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 8c090936..34275831 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -149,17 +149,17 @@ mma_unpack(MMA_Traits const& traits, //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { - detail::explode_with_d_scaling(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - traits.accumulate_); + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); } else { detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } } else { @@ -169,19 +169,19 @@ mma_unpack(MMA_Traits const& traits, CUTE_STATIC_ASSERT_V(size(rD) == Int{}); CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { - detail::explode_with_d_scaling(MMA_Op::fma, + detail::explode(MMA_Op::fma, rD, make_int_sequence{}, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}, - traits.accumulate_); + &(traits.accumulate_), seq<0>{}); } else { detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } } } @@ -198,7 +198,7 @@ template const& traits, - Tensor && D, + Tensor && D, Tensor const& A, Tensor const& B, Tensor const& C) diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 3bbcd1fb..db6f0fc2 100644 --- 
a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -208,7 +208,7 @@ make_gmma_desc(Tensor const& tensor) // Start address (4LSB not included) uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); - desc.bitfield.start_address_ = start_address >> 4; + desc.bitfield.start_address_ = static_cast(start_address >> 4); constexpr uint8_t base_offset = 0; desc.bitfield.base_offset_ = base_offset; diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 941f60d7..35d4f8fd 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -91,7 +91,7 @@ // It's harmless to use the macro for other GCC versions or other // compilers, but it has no effect. #if ! defined(CUTE_GCC_UNREACHABLE) -# if defined(__clang__) || defined(__GNUC__) +# if defined(__GNUC__) # define CUTE_GCC_UNREACHABLE __builtin_unreachable() # else # define CUTE_GCC_UNREACHABLE diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 110e233a..f8ca4671 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -325,10 +325,21 @@ CUTE_HOST_DEVICE constexpr auto ceil_div(IntTupleA const& a, IntTupleB const& b) { - if constexpr (is_tuple::value && is_tuple::value) { - static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); - constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 - return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return ceil_div(a, product(b)); } else { return (a + b - Int<1>{}) / b; } diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 71c4ce13..b7517a67 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -418,8 +418,8 @@ make_layout_like(Layout const& layout) // Make a compact layout with the same shape as @a layout // and strides following the order induced by @a layout.stride(), // except mode-0 is always stride-1 and generated column-major. -// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms -// so this generates the 0th mode with LayoutLeft regardless of the reference layout. 
+// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms so this +// generates the 0th mode with LayoutLeft (preserving stride-0s) regardless of the reference layout template CUTE_HOST_DEVICE constexpr auto @@ -427,7 +427,8 @@ make_fragment_like(Layout const& layout) { constexpr int R = Layout::rank; if constexpr (R > 1 && is_static::value) { - return tiled_product(make_layout(shape<0>(layout)), + return tiled_product(make_layout(get<0>(layout.shape()), + compact_col_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); } else { return make_layout(layout.shape()); @@ -757,7 +758,8 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, } else if constexpr (is_constant<1, NewShape>::value) { // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); - } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + } else if constexpr (is_static(new_shape))>::value && + is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { // Merge modes because the shapes and strides match return bw_coalesce(old_shape, old_stride, replace_front(new_shape, get(old_shape) * get<0>(new_shape)), @@ -772,6 +774,45 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, CUTE_GCC_UNREACHABLE; } +// cute::coalesce promises to not change the Layout as a function from integers to codomain. +// It accomplishes this inside of the Layout's domain, but not always outside of the domain. +// Example: (_4,_1):(_1,_0) coalesces to _4:_1. +// detail::coalesce_x preserves the Layout function inside its domain and outside. +// +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i, @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + if constexpr (is_constant<1, decltype(get(flat_shape))>::value) { + return detail::bw_coalesce(flat_shape, flat_stride, Int<2>{}, get(flat_stride)); + } else { + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); + } +} + +// Apply coalesce_x at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); }); + } else { + return coalesce_x(layout); + } + + CUTE_GCC_UNREACHABLE; +} + } // end namespace detail // "Simplify" the layout by combining modes that are possible to combine @@ -807,6 +848,25 @@ coalesce(Layout const& layout, IntTuple const& trg_profile) CUTE_GCC_UNREACHABLE; } +// Combine static and dynamic modes of a shape. 
+// @post size(@a result) == size(@a shape) +// @post depth(@a result) <= 1 +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Shape const& shape) +{ + static_assert(is_integral::value || is_tuple::value); + + return cute::fold_first(flatten(shape), [](auto const& init, auto const& a) { + if constexpr (is_static::value == is_static::value) { + return replace_back(init, back(init) * a); // Both static or both dynamic, coalesce and replace + } else { + return append(init, a); // Can't coalesce, so append + } + }); +} + // Replace the modes in layout that have a 0-stride with a 1-size template CUTE_HOST_DEVICE constexpr @@ -918,70 +978,64 @@ template CUTE_HOST_DEVICE constexpr auto -composition_impl(Layout const& lhs, +composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, RShape const& rhs_shape, RStride const& rhs_stride) { if constexpr (is_tuple::value) { // Apply the right-distributivity of Layout composition - return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition_impl(lhs, s, d); }); + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { + return composition_impl(lhs_shape, lhs_stride, s, d); + }); } else if constexpr (is_scaled_basis::value) { // Special case for a ScaledBasis stride - return composition_impl(get(lhs), rhs_shape, rhs_stride.value()); + return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), + rhs_shape, basis_value(rhs_stride)); } else - if constexpr (is_integral::value) { - // Integral Rstride (and RShape) + if constexpr (is_constant<0, RStride>::value) { + // Special case shortcut for any static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { + // Special case shortcut for any integral LShape + return Layout{rhs_shape, rhs_stride * lhs_stride}; + } else + if constexpr (is_constant<1, RStride>::value) { + // Special case shortcut for any static stride-1 + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); - // NOTE: Should only flatten once for efficiency - auto flat_shape = flatten(lhs.shape()); - [[maybe_unused]] auto flat_stride = flatten(lhs.stride()); - [[maybe_unused]] constexpr int R = rank(flat_shape); + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); - if constexpr (is_constant<0, RStride>::value) { - // Special case shortcut for any static stride-0 - return Layout{rhs_shape, rhs_stride}; - } else - if constexpr (is_integral::value) { - // Special case shortcut for any integral LShape - auto result_stride = rhs_stride * flat_stride; - return Layout{rhs_shape, result_stride}; - } else - if constexpr (is_constant<1, RStride>::value) { - // Special case shortcut for any static stride-1 - auto result_shape_0 = take<0,R-1>(flat_shape); + // Jump into coalesce and append (rest_shape, get(lhs_stride)) + return detail::bw_coalesce(result_shape_1, lhs_stride, rest_shape, get(lhs_stride)); + } else { + // General case: integral RShape and RStride, tuple LShape and LStride + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); + auto result_stride_0 = take<0,R-1>(lhs_stride); - // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_1, rest_shape] = 
fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), - [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); - }); + // Divide out the rhs_stride from the lhs_shape + auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), + [] (auto const& init, auto const& di) { + return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + }); - // Jump into coalesce and append (rest_shape, get(lhs.stride()) - return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); - } else - { - // General case - auto result_shape_0 = take<0,R-1>(flat_shape); - auto result_stride_0 = take<0,R-1>(flat_stride); + // Apply any lhs_shape changes to the stride + auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); - // Divide out the rhs_stride from the lhs.shape() - auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), - [] (auto const& init, auto const& di) { - return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); - }); + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); - // Apply any lhs.shape() changes to the stride - auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); - - // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), - [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); - }); - - // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) - return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); - } + // Jump into coalesce and append (rest_shape, rest_stride * get(lhs_stride)) + return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(lhs_stride)); } CUTE_GCC_UNREACHABLE; @@ -996,7 +1050,9 @@ auto composition(Layout const& lhs, Layout const& rhs) { - return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); + auto coprofile = repeat_like(decltype(coshape(rhs)){}, Int<0>{}); + auto flat_lhs = detail::coalesce_x(lhs, coprofile); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride()); } template @@ -1012,7 +1068,8 @@ composition(Layout const& lhs, } else if constexpr (is_underscore::value) { return lhs; } else if constexpr (is_integral::value) { - return detail::composition_impl(lhs, rhs, Int<1>{}); + auto flat_lhs = detail::coalesce_x(lhs); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs, Int<1>{}); } CUTE_GCC_UNREACHABLE; @@ -1032,14 +1089,14 @@ composition(Layout const& lhs, namespace detail { // @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). 
-template +template CUTE_HOST_DEVICE constexpr auto -complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) +complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) { if constexpr (is_constant<0, Stride>::value) { // Special case for irreducible rank-1 stride-0 layout - return make_layout(cosize_hi); + return make_layout(coalesce(cotarget)); } else { // General case constexpr int R = rank_v; @@ -1055,28 +1112,30 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) { auto [shape, stride, result_shape, result_stride] = init; auto min_stride = cute::min(stride); - auto min_idx = find(stride, min_stride); + auto min_idx = cute::find(stride, min_stride); auto new_shape = min_stride / get(result_stride); - auto new_stride = get(shape) * min_stride; + auto new_stride = min_stride * get(shape); static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); return cute::make_tuple(remove(shape), // Remove the min_idx from shape remove(stride), // Remove the min_idx from stride append(result_shape , new_shape ), // new shape = min_stride / last_stride - append(result_stride, new_stride)); // new stride = curr_shape * min_stride + append(result_stride, new_stride)); // new stride = min_stride * curr_shape }); // Append the last shape mode - auto new_shape = get<0>(stride_) / get(result_stride); + auto new_shape = get<0>(stride_) / get(result_stride); // new shape = min_stride / last_stride static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); - auto result_shape = append(result_shape_, new_shape); // new shape = min_stride / last_stride + auto result_shape = append(result_shape_, new_shape); // Compute the rest_shape and rest_stride - auto rest_stride = get<0>(shape_) * get<0>(stride_); - auto rest_shape = ceil_div(cosize_hi, rest_stride); + auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape + auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); + auto rest_stride = compact_col_major(rest_shape, new_stride); - // Jump into coalesce and append (rest_shape, rest_stride) - return detail::bw_coalesce(result_shape, result_stride, rest_shape, rest_stride); + // Coalesce and append (rest_shape, rest_stride) + return coalesce(make_layout(make_shape (result_shape , rest_shape ), + make_stride(result_stride, rest_stride))); } CUTE_GCC_UNREACHABLE; @@ -1084,14 +1143,13 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) } // end namespace detail -template +template CUTE_HOST_DEVICE constexpr auto -complement(Layout const& layout, CoSizeHi const& cosize_hi) +complement(Layout const& layout, CoTarget const& cotarget) { - static_assert(cute::is_integral::value, "Expected integral codomain size in complement."); auto filter_layout = filter(layout); - return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize_hi); + return detail::complement(filter_layout.shape(), filter_layout.stride(), shape(cotarget)); } template @@ -1365,7 +1423,7 @@ auto logical_divide(Layout const& layout, Layout const& tiler) { - return composition(layout, make_layout(tiler, complement(tiler, size(layout)))); + return composition(layout, make_layout(tiler, complement(tiler, shape(layout)))); } template diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 93c60898..3dbd2cd9 100644 --- a/include/cute/layout_composed.hpp +++ 
b/include/cute/layout_composed.hpp @@ -392,12 +392,12 @@ composition(Layout const& a, // complement // -template +template CUTE_HOST_DEVICE constexpr auto -complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) +complement(ComposedLayout const& layout, CoTarget const& cotarget) { - return complement(layout.layout_b(), cosize_hi); + return complement(layout.layout_b(), cotarget); } template @@ -610,7 +610,7 @@ recast_layout(ComposedLayout const& layout) else if constexpr (scale::num == 1) { return downcast(layout); } - else if constexpr (scale::den == 1) { + else if constexpr (scale::den == 1) { return upcast(layout); } else { diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 27d1cf8e..651ff8e8 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -73,25 +73,17 @@ make_arithmetic_tuple(T const&... t) { return ArithmeticTuple(t...); } -template +template CUTE_HOST_DEVICE constexpr auto -as_arithmetic_tuple(tuple const& t) { - return ArithmeticTuple(t); -} - -template ::value)> -CUTE_HOST_DEVICE constexpr -T const& as_arithmetic_tuple(T const& t) { - return t; -} - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t) { - return t; + if constexpr (is_tuple::value) { + return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); }, + [](auto const&... a){ return make_arithmetic_tuple(a...); }, + tuple_seq{}); + } else { + return t; + } } // @@ -289,6 +281,26 @@ basis_get(SB const& e, Tuple const& t) namespace detail { +template +CUTE_HOST_DEVICE constexpr +auto +to_atuple_i(T const& t, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-N+1 ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq{}); +} + +namespace detail { + template struct Basis; @@ -315,71 +327,6 @@ struct Basis { template using E = typename detail::Basis::type; -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(T const& t, seq, seq) { - return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); -} - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { - return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); -} - -} // end namespace detail - -// Turn a ScaledBases into a rank-M ArithmeticTuple -// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ScaledBasis const& t) { - static_assert(M > N, "Mismatched ranks"); - return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); -} - -// Turn a ScaledBases into a rank-N ArithmeticTuple -// with N prefix 0s: (_0,_0,...N...,_0,T) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ScaledBasis const& t) { - return as_arithmetic_tuple(t); -} - -// Turn an ArithmeticTuple into a rank-M ArithmeticTuple -// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t) { - static_assert(M >= sizeof...(T), "Mismatched ranks"); - return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); -} - -template -CUTE_HOST_DEVICE constexpr -auto -safe_div(ScaledBasis const& b, U const& u) -{ - auto t = 
safe_div(b.value(), u); - return ScaledBasis{t}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -shape_div(ScaledBasis const& b, U const& u) -{ - auto t = shape_div(b.value(), u); - return ScaledBasis{t}; -} - template CUTE_HOST_DEVICE constexpr auto @@ -387,8 +334,7 @@ make_basis_like(Shape const& shape) { if constexpr (is_integral::value) { return Int<1>{}; - } - else { + } else { // Generate bases for each rank of shape return transform(tuple_seq{}, shape, [](auto I, auto si) { // Generate bases for each rank of si and add an i on front @@ -408,6 +354,28 @@ make_basis_like(Shape const& shape) CUTE_GCC_UNREACHABLE; } +// +// Arithmetic +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(ScaledBasis const& b, U const& u) +{ + auto t = shape_div(b.value(), u); + return ScaledBasis{t}; +} + // Equality template CUTE_HOST_DEVICE constexpr @@ -432,7 +400,7 @@ operator==(T const&, ScaledBasis const&) { } // Abs -template +template CUTE_HOST_DEVICE constexpr auto abs(ScaledBasis const& e) { @@ -440,7 +408,7 @@ abs(ScaledBasis const& e) { } // Multiplication -template +template CUTE_HOST_DEVICE constexpr auto operator*(A const& a, ScaledBasis const& e) { @@ -448,7 +416,7 @@ operator*(A const& a, ScaledBasis const& e) { return ScaledBasis{r}; } -template +template CUTE_HOST_DEVICE constexpr auto operator*(ScaledBasis const& e, B const& b) { @@ -457,44 +425,25 @@ operator*(ScaledBasis const& e, B const& b) { } // Addition -template -CUTE_HOST_DEVICE constexpr -auto -operator+(ScaledBasis const& t, ArithmeticTuple const& u) { - constexpr int R = cute::max(N+1, int(sizeof...(U))); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template -CUTE_HOST_DEVICE constexpr -auto -operator+(ArithmeticTuple const& t, ScaledBasis const& u) { - constexpr int R = cute::max(int(sizeof...(T)), M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template -CUTE_HOST_DEVICE constexpr -auto -operator+(ScaledBasis const& t, tuple const& u) { - constexpr int R = cute::max(N+1, int(sizeof...(U))); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template -CUTE_HOST_DEVICE constexpr -auto -operator+(tuple const& t, ScaledBasis const& u) { - constexpr int R = cute::max(int(sizeof...(T)), M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template +template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, ScaledBasis const& u) { - constexpr int R = cute::max(N+1,M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + return as_arithmetic_tuple(t) + u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + return t + as_arithmetic_tuple(u); } template diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 8cc36253..5113719d 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -56,10 +56,10 @@ fma(complex & d, complex const& b, complex const& c) { - d.real(fma( a.real(), b.real(), c.real())); - d.imag(fma( a.real(), b.imag(), c.imag())); - d.real(fma(-a.imag(), b.imag(), d.real())); - d.imag(fma( a.imag(), b.real(), d.imag())); + fma(d.real(), a.real(), b.real(), c.real()); + 
fma(d.imag(), a.real(), b.imag(), c.imag()); + fma(d.real(), -a.imag(), b.imag(), d.real()); + fma(d.imag(), a.imag(), b.real(), d.imag()); } /// Fused multiply-add for triplets diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 5647f97c..604477a0 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -41,7 +41,6 @@ #include #include -#include namespace cute { @@ -102,6 +101,8 @@ template // Found the gmem struct is_gmem> : true_type {}; template // Recurse on ::iterator, if possible struct is_gmem> : is_gmem {}; +template +constexpr bool is_gmem_v = is_gmem

::value; // Idempotent gmem tag on an iterator template @@ -163,6 +164,8 @@ template // Found the smem struct is_smem> : true_type {}; template // Recurse on ::iterator, if possible struct is_smem> : is_smem {}; +template +constexpr bool is_smem_v = is_smem

::value; // Idempotent smem tag on an iterator template @@ -224,6 +227,8 @@ template struct is_rmem : bool_constant::value || is_smem::value)> {}; template struct is_rmem> : true_type {}; +template +constexpr bool is_rmem_v = is_rmem

::value; // Idempotent rmem tag on an iterator template diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp index aa917d9f..08751eb1 100644 --- a/include/cute/pointer_flagged.hpp +++ b/include/cute/pointer_flagged.hpp @@ -89,7 +89,7 @@ downcast(ComposedLayout,Layout> const& layout) // Conversion with swizzle_layout // -template +template CUTE_HOST_DEVICE auto as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) @@ -129,6 +129,14 @@ as_position_independent_swizzle_tensor(Tensor&& tensor) // // Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_layout(ComposedLayout,Layout> const& layout) +{ + print_layout(as_position_independent_swizzle_layout(layout)); +} + template CUTE_HOST_DEVICE void diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 28d3ee67..71ace9a8 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -316,6 +316,8 @@ template struct is_tensor : false_type {}; template struct is_tensor> : true_type {}; +template +constexpr bool is_tensor_v = is_tensor::value; // Customization point for creation of owning and non-owning Tensors template @@ -1082,7 +1084,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const #include #include - // // Tensor Algorithms // diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index a8cab903..56cc814e 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -101,6 +101,9 @@ using CUTE_STL_NAMESPACE::is_lvalue_reference_v; using CUTE_STL_NAMESPACE::is_reference; using CUTE_STL_NAMESPACE::is_trivially_copyable; +using CUTE_STL_NAMESPACE::is_convertible; +using CUTE_STL_NAMESPACE::is_convertible_v; + using CUTE_STL_NAMESPACE::is_same; using CUTE_STL_NAMESPACE::is_same_v; @@ -247,4 +250,15 @@ is_valid(F&&, Args&&...) { return detail::is_valid_impl(int{}); } +template class True, template class False> +struct conditional_template { + template + using type = True; +}; + +template class True, template class False> +struct conditional_template { + template + using type = False; +}; } // end namespace cute diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 48ad7cca..dcaa1093 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -33,16 +33,6 @@ and is safe to use in a union. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
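The `is_gmem_v` / `is_smem_v` / `is_rmem_v` / `is_tensor_v` helpers added above are plain `::value` shorthands; a hedged sketch of how they can drive memory-space dispatch (the `dispatch` function itself is hypothetical, not part of the patch):

```cpp
#include <cute/tensor.hpp>

// Hypothetical helper: choose a code path from the tensor engine's memory space.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void dispatch(cute::Tensor<Engine, Layout> const& t) {
  static_assert(cute::is_tensor_v<cute::Tensor<Engine, Layout>>);
  (void)t;  // silence unused-parameter warnings in this sketch
  if constexpr (cute::is_smem_v<Engine>) {
    // shared-memory path (e.g. swizzled smem loads)
  } else if constexpr (cute::is_gmem_v<Engine>) {
    // global-memory path (e.g. cp.async / TMA)
  } else {
    static_assert(cute::is_rmem_v<Engine>, "expected a register-backed tensor");
    // register path
  }
}
```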
-*/ - #pragma once #include "cutlass/cutlass.h" #include "cutlass/functional.h" @@ -57,7 +47,7 @@ template < int N, bool RegisterSized = sizeof_bits::value >= 32 > -class Array; +struct Array; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -90,8 +80,7 @@ template < typename T, int N > -class Array { -public: +struct Array { /// Storage type using Storage = T; @@ -101,10 +90,10 @@ public: /// Number of storage elements //static std::size_t const kStorageElements = N; - static size_t const kStorageElements = N; + static constexpr size_t kStorageElements = N; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; // // C++ standard members @@ -346,26 +335,9 @@ public: } }; -private: - /// Internal storage Storage storage[kElements]; -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElements; ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -530,39 +502,25 @@ public: template CUTLASS_HOST_DEVICE Array make_Array(Element x) { - Array m; - m[0] = x; - return m; + return {x}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y) { - Array m; - m[0] = x; - m[1] = y; - return m; + return {x,y}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - return m; + return {x,y,z}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z, Element w) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - m[3] = w; - return m; + return {x,y,z,w}; } @@ -1104,6 +1062,58 @@ struct square_and_plus> { } }; +/// Inverse-square-root +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + Array result; + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + return result; + } +}; + +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & a) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = h2rsqrt(a_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half d_residual = hrsqrt(a_residual_ptr[N - 1]); + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + + #endif + + return result; + } +}; + /// Fused multiply-add-relu0 template struct multiply_add_relu0, Array, Array> { @@ -2513,7 +2523,6 @@ struct bit_xor> { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// // Operator overloads ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -2525,6 +2534,20 @@ Array operator+(Array const &lhs, Array const &rhs) { return op(lhs, rhs); } +template +CUTLASS_HOST_DEVICE +Array operator+(T const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, T const &rhs) { + plus> 
op; + return op(lhs, rhs); +} + template CUTLASS_HOST_DEVICE Array operator-(Array const &lhs, Array const &rhs) { diff --git a/include/cutlass/array_planar_complex.h b/include/cutlass/array_planar_complex.h index 9fcd4d18..2dd8aa84 100644 --- a/include/cutlass/array_planar_complex.h +++ b/include/cutlass/array_planar_complex.h @@ -51,13 +51,12 @@ struct ArrayPlanarComplex { using Element = Element_; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Underlying Fragment of real-valued elemenets - using ArrayReal = Array; + using ArrayReal = cutlass::Array; public: - /// Fragment of real-valued elements representing the real part ArrayReal real; @@ -65,19 +64,6 @@ public: ArrayReal imag; public: - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex() { } - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex( - ArrayReal const &real_, - ArrayReal const &imag_ - ): - real(real_), imag(imag_) { } - /// Sets the array to zero efficiently CUTLASS_HOST_DEVICE void clear() { @@ -93,7 +79,7 @@ template CUTLASS_HOST_DEVICE ArrayPlanarComplex make_ArrayPlanarComplex(Array const &real, Array const &imag) { - return ArrayPlanarComplex(real, imag); + return ArrayPlanarComplex{real, imag}; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 25bbe355..eb77a931 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -32,15 +32,6 @@ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
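Besides the constexpr cleanups, this array.h hunk adds `inverse_square_root` functors, scalar-broadcast `operator+` overloads, and brace-initialized `make_Array`. A small illustrative snippet (values are made up):

```cpp
#include "cutlass/array.h"

CUTLASS_HOST_DEVICE
cutlass::Array<float, 4> broadcast_demo() {
  auto a = cutlass::make_Array(1.0f, 2.0f, 3.0f, 4.0f);  // now returns {x, y, z, w} directly

  // The new overloads broadcast a scalar across an Array via plus<Array<T, N>>:
  cutlass::Array<float, 4> b = 0.5f + a;
  cutlass::Array<float, 4> c = a + 0.5f;

  return b + c;  // element-wise Array + Array, unchanged
}
```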
-*/ #pragma once @@ -57,10 +48,8 @@ template < typename T, int N > -class Array { -public: - - static int const kSizeBits = sizeof_bits::value * N; +struct Array { + static constexpr int kSizeBits = sizeof_bits::value * N; /// Storage type using Storage = typename platform::conditional< @@ -77,16 +66,16 @@ public: using Element = T; /// Number of logical elements per stored object - static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; + static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements - static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Bitmask for covering one item - static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); + static constexpr Storage kMask = ((Storage(1) << sizeof_bits::value) - 1); // // C++ standard members with pointer types removed @@ -105,16 +94,14 @@ public: /// Reference object inserts or extracts sub-byte items class reference { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - /// Default ctor - CUTLASS_HOST_DEVICE - reference(): ptr_(nullptr), idx_(0) { } + reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -123,11 +110,38 @@ public: /// Assignment CUTLASS_HOST_DEVICE reference &operator=(T x) { + // `*ptr_ & kUpdateMask` will read ptr_ before write to it + // This means code pattern like + // + // ```cpp + // Array result; + // result[0] = xxx; + // ``` + // + // Will leads to compiler warning on use of unintialized member variable. Although we know + // this read of uninitialized member variable is harmeless. 
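The pattern that comment describes looks like the following hypothetical snippet: writing one packed element read-modify-writes the shared Storage word, so the first store into an otherwise uninitialized Array reads uninitialized (but unused) bits.

```cpp
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

CUTLASS_HOST_DEVICE
void subbyte_write_demo() {
  cutlass::Array<cutlass::int4b_t, 8> result;  // eight 4-bit values share one 32-bit Storage word
  result[0] = cutlass::int4b_t(-2);            // reference::operator= reads the word, masks, rewrites it
  result[3] = cutlass::int4b_t(5);
  (void)result;
}
```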
+ +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wuninitialized" +#elif defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuninitialized" +# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + Storage item = (reinterpret_cast(x) & kMask); Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); + *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); +#if defined(__clang__) +# pragma clang diagnostic pop +#elif defined(__GNUC__) +# pragma GCC diagnostic pop +#endif + return *this; } @@ -160,16 +174,14 @@ public: class const_reference { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - /// Default ctor - CUTLASS_HOST_DEVICE - const_reference(): ptr_(nullptr), idx_(0) { } + const_reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -209,15 +221,14 @@ public: class iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - iterator(): ptr_(nullptr), idx_(0) { } + iterator() = default; CUTLASS_HOST_DEVICE iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -288,15 +299,14 @@ public: class const_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_iterator(): ptr_(nullptr), idx_(0) { } + const_iterator() = default; CUTLASS_HOST_DEVICE const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -367,15 +377,14 @@ public: class reverse_iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - reverse_iterator(): ptr_(nullptr), idx_(0) { } + reverse_iterator() = default; CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -385,40 +394,19 @@ public: class const_reverse_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_reverse_iterator(): ptr_(nullptr), idx_(0) { } + const_reverse_iterator() = default; CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } }; -private: - - /// Internal storage - Storage storage[kStorageElements] = {Storage{0}}; - -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(kStorageElements); ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -489,7 +477,6 @@ public: return storage; } - CUTLASS_HOST_DEVICE constexpr bool empty() const { return !kElements; @@ -560,10 +547,9 @@ public: return const_reverse_iterator(storage); } - // - // Comparison operators - // - +private: + /// Internal storage + Storage storage[kStorageElements]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 75cadbfa..c2e6cb0d 100644 --- 
a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -34,16 +34,6 @@ 8 bits of exponent and 7 bit of mantissa. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 0996fff8..36f603e3 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -229,7 +229,7 @@ struct ClusterLaunchParams { /// void const* kernel_ptr = /// const_cast(reinterpret_cast( /// &kernel)); -/// auto status = launch_on_cluster( +/// auto status = launch_kernel_on_cluster( /// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, /// kernel_ptr, x, y, z); /// @endcode @@ -243,10 +243,10 @@ launch_kernel_on_cluster(const ClusterLaunchParams& params, // the parameters as an array of raw pointers. if constexpr (sizeof...(Args) == 0) { return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, + params.grid_dims, + params.cluster_dims, params.block_dims, - params.smem_size_in_bytes, + params.smem_size_in_bytes, params.cuda_stream, kernel_ptr, nullptr); } @@ -255,12 +255,12 @@ launch_kernel_on_cluster(const ClusterLaunchParams& params, detail::checked_addressof(std::forward(args))... }; return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, + params.grid_dims, + params.cluster_dims, params.block_dims, - params.smem_size_in_bytes, + params.smem_size_in_bytes, params.cuda_stream, - kernel_ptr, + kernel_ptr, kernel_params); } } diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 32cfa5f7..1f92b667 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -28,15 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ #pragma once diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index ec864219..d2e89529 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -44,15 +44,6 @@ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
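The cluster_launch.hpp hunk corrects its doxygen example to name `launch_kernel_on_cluster`. A hedged sketch of that corrected usage (kernel and dimensions are invented; assumes the ClusterLaunchParams member order shown in the comment and an SM90 target):

```cpp
#include "cutlass/cluster_launch.hpp"

__global__ void my_kernel(int x) { /* ... */ }

cutlass::Status launch_example(cudaStream_t stream) {
  cutlass::ClusterLaunchParams params{
      /*grid_dims=*/{8, 1, 1},
      /*block_dims=*/{128, 1, 1},
      /*cluster_dims=*/{2, 1, 1},
      /*smem_size_in_bytes=*/0,
      stream};

  int x = 42;
  void const* kernel_ptr = reinterpret_cast<void const*>(&my_kernel);
  return cutlass::launch_kernel_on_cluster(params, kernel_ptr, x);
}
```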
-*/ #pragma once @@ -247,9 +238,10 @@ public: /// Returns filter extent as Tensor4DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord filter_extent() const { + cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor4DCoord ({K, R, S, C / groups}); + return is_deconv ? cutlass::Tensor4DCoord ({C, R, S, K / groups}) + : cutlass::Tensor4DCoord ({K, R, S, C / groups}); } /// Returns output extent as Tensor4DCoord @@ -340,6 +332,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( problem_size.K, problem_size.R * problem_size.S * problem_size.C / problem_size.groups ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.H * problem_size.W, @@ -404,6 +397,7 @@ int implicit_gemm_k_iterations( iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); break; + case Operator::kDeconv: case Operator::kDgrad: elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); @@ -505,6 +499,7 @@ int implicit_gemm_k_iterations_per_channel( iterations = problem_size.R * problem_size.S; break; + case Operator::kDeconv: case Operator::kDgrad: iterations = problem_size.R * problem_size.S; break; @@ -526,6 +521,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -540,6 +536,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -554,6 +551,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -568,6 +566,7 @@ int64_t implicit_gemm_tensor_a_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -582,6 +581,7 @@ int64_t implicit_gemm_tensor_b_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; @@ -596,6 +596,7 @@ 
int64_t implicit_gemm_tensor_c_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; diff --git a/include/cutlass/conv/conv3d_problem_size.h b/include/cutlass/conv/conv3d_problem_size.h index 56b16423..9a9514f2 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -44,15 +44,6 @@ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once @@ -277,9 +268,10 @@ public: /// Returns filter extent as Tensor5DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord filter_extent() const { + cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor5DCoord ({K, T, R, S, C}); + return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K}) + : cutlass::Tensor5DCoord ({K, T, R, S, C}); } /// Returns output extent as Tensor5DCoord @@ -351,6 +343,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( problem_size.K, problem_size.T * problem_size.R * problem_size.S * problem_size.C ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.D * problem_size.H * problem_size.W, @@ -387,7 +380,8 @@ int implicit_gemm_k_iterations( elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); break; - + + case Operator::kDeconv: case Operator::kDgrad: elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); @@ -430,6 +424,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -444,6 +439,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -458,6 +454,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( Conv3dProblemSize const &problem_size) { switch 
(conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -472,6 +469,7 @@ int64_t implicit_gemm_tensor_a_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -486,6 +484,7 @@ int64_t implicit_gemm_tensor_b_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; @@ -500,6 +499,7 @@ int64_t implicit_gemm_tensor_c_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index a61f573e..243ee269 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -70,16 +70,6 @@ Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" @@ -98,7 +88,8 @@ namespace conv { enum class Operator { kFprop, kDgrad, - kWgrad + kWgrad, + kDeconv }; /// Distinguishes convolution from cross correlation diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 69cfbaba..603c47e8 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -144,6 +144,11 @@ private: public: + /// Access the Params structure + Params const& params() const { + return params_; + } + /// Determines whether the conv can execute the given problem. 
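The kDeconv cases added here reuse the Dgrad mappings except for operand B, where `filter_extent(true)` swaps the roles of K and C. A small host-side illustration with hypothetical sizes:

```cpp
#include "cutlass/conv/conv2d_problem_size.h"

void deconv_extent_demo() {
  cutlass::conv::Conv2dProblemSize p;  // data members are public
  p.N = 1; p.H = 8; p.W = 8; p.C = 32;
  p.K = 64; p.R = 3; p.S = 3;          // groups defaults to 1

  auto fprop_b  = p.filter_extent();                    // (K, R, S, C/groups) = (64, 3, 3, 32)
  auto deconv_b = p.filter_extent(/*is_deconv=*/true);  // (C, R, S, K/groups) = (32, 3, 3, 64)

  // The operator-dispatched helper used by the kernels gives the same result for kDeconv:
  auto b_extent = cutlass::conv::implicit_gemm_tensor_b_extent(
      cutlass::conv::Operator::kDeconv, p);
  (void)fprop_b; (void)deconv_b; (void)b_extent;
}
```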
static Status can_implement(Arguments const& args) { @@ -323,13 +328,12 @@ public: } } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; - - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params); - + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } } } else { diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index d2319f2c..62c7e871 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -153,7 +153,7 @@ public: if (kConvolutionalOperator == conv::Operator::kFprop) { if (args.problem_size.K % kAlignmentC) return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) { if (args.problem_size.C % kAlignmentC) return Status::kErrorMisalignedOperand; } else if (kConvolutionalOperator == conv::Operator::kWgrad) { @@ -161,16 +161,16 @@ public: return Status::kErrorMisalignedOperand; } - // check for unsupported problem sizes for strided dgrad implementation - if (kConvolutionalOperator == conv::Operator::kDgrad && + // check for unsupported problem sizes for strided dgrad / deconv implementation + if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && kStrideSupport == conv::StrideSupport::kStrided) { - // split-k (serial or parallel) is not supported for strided dgrad + // split-k (serial or parallel) is not supported for strided dgrad / deconv if(args.problem_size.split_k_slices > 1) { return Status::kErrorNotSupported; } - - // dilation > {1x1} is not supported for strided dgrad + + // dilation > {1x1} is not supported for strided dgrad / deconv if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { return Status::kErrorNotSupported; } diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h index 51310304..f629bbb2 100644 --- a/include/cutlass/conv/kernel/default_conv2d.h +++ b/include/cutlass/conv/kernel/default_conv2d.h @@ -128,6 +128,28 @@ struct DefaultConvEpilogueWithBroadcastSimt { >::Epilogue; }; +template < + typename ArchTag, + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad< + Shape, + WarpMmaSimt, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; + template < typename ArchTag, typename Shape, diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 1c7f3444..9fbd97e5 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = 
StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -327,7 +327,6 @@ struct DefaultConv2dFprop < >; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage @@ -1167,7 +1166,11 @@ struct DefaultConv2dFprop < WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1628,7 +1631,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1741,7 +1748,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1751,7 +1762,6 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, conv::Operator::kFprop >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1853,7 +1863,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1967,7 +1981,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 4100c8dd..8589ace0 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv2dFpropFusion; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h index b0e0ae65..76bc1288 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -69,7 +69,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index 6e6127d7..0825789c 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ 
b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -31,7 +31,7 @@ /*! \file \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. */ @@ -71,7 +71,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -143,6 +143,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB @@ -164,7 +165,7 @@ struct DefaultConv2dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB @@ -184,7 +185,7 @@ struct DefaultConv2dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index 978d49c9..e6e8a822 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -72,7 +72,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h index 927f70ce..e2deaf6f 100644 --- a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -77,7 +77,7 @@ template < typename MathOperatorTag, conv::GroupMode GroupMode, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop.h b/include/cutlass/conv/kernel/default_conv3d_fprop.h index 3ea1e11c..41fdd64a 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -73,7 +73,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv3dFprop; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -94,7 +94,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename 
MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -113,7 +114,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -202,7 +204,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -221,7 +224,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -306,7 +310,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -325,7 +330,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -416,7 +422,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -435,7 +442,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -492,7 +500,11 @@ struct DefaultConv3dFprop < WarpMmaTensorOp, 1, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -526,7 +538,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -545,7 +558,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -598,7 +612,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -632,7 +650,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -651,7 +670,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -706,7 +726,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -738,7 +762,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -757,7 +782,8 @@ 
struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -813,7 +839,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -845,7 +875,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -864,7 +895,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -922,7 +954,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -933,10 +969,10 @@ struct DefaultConv3dFprop < conv::Operator::kFprop, Conv3dProblemSize >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace kernel } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h index 34497715..d0457d57 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -77,7 +77,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv3dFpropFusion; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h index 1d70a29e..38e4de5c 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -31,7 +31,7 @@ /*! \file \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
*/ @@ -71,7 +71,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -142,6 +142,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB @@ -163,7 +164,7 @@ struct DefaultConv3dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB @@ -183,7 +184,7 @@ struct DefaultConv3dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport >::Kernel; diff --git a/include/cutlass/conv/kernel/default_deconv2d.h b/include/cutlass/conv/kernel/default_deconv2d.h new file mode 100644 index 00000000..ace21b92 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d.h @@ -0,0 +1,983 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
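For context, a hedged sketch of how this new DefaultDeconv2d kernel might be instantiated; the tile shapes, stage count, and epilogue are illustrative choices loosely modeled on the new SIMT unit tests, not prescribed by this header:

```cpp
#include "cutlass/conv/kernel/default_deconv2d.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/layout/tensor.h"

using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d<
    float, cutlass::layout::TensorNHWC,     // ElementA, LayoutA
    float, cutlass::layout::TensorNHWC,     // ElementB, LayoutB
    float, cutlass::layout::TensorNHWC,     // ElementC, LayoutC
    float,                                  // ElementAccumulator
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 8>,  // ThreadblockShape
    cutlass::gemm::GemmShape<32, 64, 8>,    // WarpShape
    cutlass::gemm::GemmShape<1, 1, 1>,      // InstructionShape (SIMT)
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                      // Stages
    cutlass::arch::OpMultiplyAdd
>::Kernel;                                  // IteratorAlgorithm / StrideSupport take the defaults above

using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution<Deconv2dKernel>;
```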
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDeconv2d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + 
using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + 
EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct 
DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using 
ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // 
Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + 
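+// Illustrative usage sketch (assumption-based, not part of the definitions above): the
+// ::Kernel type produced by a DefaultDeconv2d specialization is normally wrapped in the
+// device-level adapter rather than launched directly. Here Deconv2dKernel stands for a
+// DefaultDeconv2d<...>::Kernel such as the one sketched in the file-level comment above;
+// the problem-size extents, tensor references, and scalars are placeholders.
+//
+//   using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution<Deconv2dKernel>;
+//
+//   cutlass::conv::Conv2dProblemSize problem_size(
+//       {1, 8, 8, 64},       // input size  (N, H, W, C)
+//       {64, 3, 3, 32},      // filter size (K, R, S, C)
+//       {1, 1, 1, 1},        // padding     (pad_h, pad_h, pad_w, pad_w)
+//       {1, 1},              // traversal stride
+//       {1, 1});             // dilation
+//
+//   Deconv2d deconv_op;
+//   typename Deconv2d::Arguments args(
+//       problem_size,
+//       ref_A, ref_B, ref_C, ref_D,   // cutlass::TensorRef handles prepared elsewhere
+//       {alpha, beta});               // epilogue (LinearCombination) parameters
+//
+//   cutlass::Status status = deconv_op.initialize(args);
+//   if (status == cutlass::Status::kSuccess) {
+//     status = deconv_op();           // launches the implicit GEMM deconvolution
+//   }
+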
+///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h new file mode 100644 index 00000000..d11432ed --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv2dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define 
epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv3d.h b/include/cutlass/conv/kernel/default_deconv3d.h new file mode 100644 index 00000000..e9eb4cc5 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d.h @@ -0,0 +1,525 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultDeconv3d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + 
ThreadMapB, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + 
Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 
EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h new file mode 100644 index 00000000..5c50c766 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv3dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt 
convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename 
ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_depthwise_fprop.h b/include/cutlass/conv/kernel/default_depthwise_fprop.h index cbe84b1e..aa4f2c35 100644 --- a/include/cutlass/conv/kernel/default_depthwise_fprop.h +++ b/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -80,7 +80,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -109,7 +109,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, // MatrixShape typename StrideShape = cutlass::MatrixShape<-1, -1>, // MatrixShape< Height, Width> diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h index a3468bda..5e429956 100644 --- a/include/cutlass/conv/kernel/direct_convolution.h +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -155,7 +155,7 @@ struct DirectConvolutionParams { swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); // Dynamic SMEM usage because stride and dilation are runtime params. - smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); + smem_size_ = (max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index f8cee810..c4de265e 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -61,7 +61,7 @@ template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode > @@ -233,9 +233,9 @@ struct ImplicitGemmConvolution { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), @@ -397,7 +397,6 @@ struct ImplicitGemmConvolution { threadblock_offset ); - // Construct the epilogue Epilogue epilogue( shared_storage.epilogue, diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index ded39ffa..c768a296 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -61,7 +61,7 @@ template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem > struct ImplicitGemmConvolutionWithFusedEpilogue { diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index c6996f15..43c6d595 100644 --- a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -81,7 +81,6 @@ public: using MainloopParams = typename CollectiveMainloop::Params; static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; static_assert(ArchTag::kMinComputeCapability >= 90); - // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 968c91e2..1725db5a 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -67,7 +67,8 @@ template < typename Layout_, typename ThreadMap_, typename AccessType_ = cutlass::AlignedArray, - conv::GroupMode GroupMode_ = conv::GroupMode::kNone + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorAnalytic { public: @@ -85,6 +86,7 @@ public: using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; @@ -95,7 +97,7 @@ public: static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), "Vectors implied by the thread map must be divisible by the access 
type."); - + // // Simplifying assertions // @@ -152,13 +154,16 @@ public: filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + if (kGroupMode != conv::GroupMode::kNone) { filter_c_init_ = filter_c_; if (kGroupMode == conv::GroupMode::kDepthwise){ channels_per_group_ = 1; crs_per_group_ = problem_size_.S * problem_size_.R; } else { - channels_per_group_ = problem_size_.C / problem_size_.groups; + channels_per_group_ = input_channels / problem_size_.groups; crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); } } @@ -167,7 +172,7 @@ public: for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { - group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (output_channels / problem_size_.groups); } } @@ -241,12 +246,15 @@ public: TensorCoord coord = at(); + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + if (kGroupMode == conv::GroupMode::kNone) { - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + return coord.n() < output_channels && coord.c() < input_channels; } else if (kGroupMode == conv::GroupMode::kDepthwise) { - return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. + return coord.n() < output_channels && coord.c() < 1; // channels_per_group_ is always equal to ONE. } else { - return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && + return coord.n() < output_channels && coord.c() < channels_per_group_ && group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; } } @@ -289,19 +297,22 @@ public: CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 3fc640b1..4c2343c3 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -67,7 +67,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -85,6 +86,7 @@ public: using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; @@ -176,11 +178,11 @@ public: filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); Index column = threadblock_offset.column() + thread_coord.strided(); - channels_per_group_ = problem_size_.C / problem_size_.groups; + channels_per_group_ = (IsDeconv ? problem_size_.K : problem_size_.C) / problem_size_.groups; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { @@ -287,19 +289,22 @@ public: CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 5ef1ab5f..85dd37ff 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -64,7 +64,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorAnalytic { public: @@ -82,6 +83,7 @@ public: using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -198,8 +200,11 @@ public: TensorCoord coord = at(); - return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + return coord.n() < output_channels && + coord.c() < input_channels; } /// Returns a pointer to the vector starting at the current coordinate @@ -234,8 +239,10 @@ public: CUTLASS_HOST_DEVICE static Status can_implement(ConvProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (input_channels % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } return Status::kSuccess; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index eb51ceb6..ac49cf07 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -66,7 +66,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorOptimized{ public: @@ -84,6 +85,7 @@ public: using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -172,11 +174,11 @@ public: CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); predicates_ |= (pred << s); } - if (filter_c_ >= problem_size.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0u; } @@ -214,7 +216,7 @@ public: filter_c_ += params_.filter_c_delta; } - if (filter_c_ >= problem_size_.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0; } @@ -259,8 +261,10 @@ public: CUTLASS_HOST_DEVICE static Status can_implement(Conv3dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (input_channels % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index d8c1d41a..d778046c 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -32,16 +32,6 @@ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index e7c96d05..40ae2224 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -31,15 +31,6 @@ /*! \file \brief Helpers for printing cutlass/core objects */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. 
- - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 88d718e0..f3965283 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -33,16 +33,6 @@ \brief Basic include for CUTLASS. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/detail/helper_macros.hpp" diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index 10aad81d..a14696b2 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -61,7 +61,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule, + class EpilogueScheduleType, class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, class Enable = void > diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index f23a74e0..9eb4c4b1 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -723,7 +723,8 @@ public: } // Vectorized fragment loop with visitor callback entry point - int r2s_v = epi_n * size(tRS_rD_frg); + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD_frg); CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e3160fa1..b8cac856 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -634,17 +634,19 @@ struct Sm90AuxLoad< return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, ResidueMN residue_mn_, Params const& params_) + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ResidueMN residue_mn_, Params const& params_) : tC_rAux(cute::forward(tC_rAux_)), tC_gAux(cute::forward(tC_gAux_)), + tC_cAux(tC_cAux_), residue_mn(residue_mn_), params(params_) {} RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ResidueMN residue_mn; Params const& params; @@ -657,8 +659,18 @@ struct Sm90AuxLoad< } } - if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { // (partially) in-bounds CTA tile - copy_aligned(tC_gAux, tC_rAux); + constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); + if constexpr (V > 0) { + using VecType = uint_bit_t; + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_rAux_vec = recast(tC_rAux); 
+ Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); // only works if vector is logically sequential + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux_vec, tC_rAux_vec); + } + else { + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux, tC_rAux); } } } @@ -672,9 +684,8 @@ struct Sm90AuxLoad< } } - if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { - copy_aligned(tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); - } + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); } } @@ -723,8 +734,8 @@ struct Sm90AuxLoad< } } - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); } }; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 1a226a75..c37f2b9a 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -172,7 +172,7 @@ struct ReLu> { template struct Clamp { struct Arguments { - T lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::min(); + T lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); T upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); }; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 192bc6d1..7456ae8d 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -63,8 +63,52 @@ template struct kIsHeavy_member_or_false::type> { static constexpr bool value = Op::kIsHeavy; }; + } // namespace (anonymous) +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct EmptyArguments {}; + +template +struct ElementwiseOpDispatcher { + using Arguments = EmptyArguments; + + T op; + + CUTLASS_HOST_DEVICE + ElementwiseOpDispatcher(Arguments) {} + + template + CUTLASS_HOST_DEVICE + ValueType operator()(ValueType value) { + return op(value); + } +}; + +template +struct ElementwiseOpDispatcher> { + using Arguments = typename T::Arguments; + + Arguments args; + T op; + + CUTLASS_HOST_DEVICE + ElementwiseOpDispatcher(Arguments args_):args(args_) {} + + template + CUTLASS_HOST_DEVICE + ValueType operator()(ValueType value) { + return op(value, args); + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// This base class is meant to define the concept required of the /// EpilogueWithBroadcast::OutputOp template < @@ -95,9 +139,13 @@ public: using ElementwiseOp = ElementwiseOp_; using BinaryOp = BinaryOp_; + using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher; + using ElementwiseArguments = typename ElementwiseOpDispatcher::Arguments; + // Indicates that this epilogue applies only one binary operation static bool const kIsSingleSource = true; + using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentC = Array; @@ -127,6 +175,7 @@ public: ElementCompute beta; ///< scales source 
tensor ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementwiseArguments elementwise; ///< Arguments for elementwise operation // // Methods @@ -142,8 +191,9 @@ public: CUTLASS_HOST_DEVICE Params( ElementCompute alpha, - ElementCompute beta - ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + ElementCompute beta, + ElementwiseArguments elementwise_ = ElementwiseArguments{} + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), elementwise(elementwise_) { } @@ -157,8 +207,9 @@ public: CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr, - ElementCompute const *beta_ptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + ElementCompute const *beta_ptr, + ElementwiseArguments elementwise_ = ElementwiseArguments{} + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) { } @@ -178,6 +229,7 @@ private: ElementCompute alpha_; ElementCompute beta_; + ElementwiseArguments const &elementwise_; bool skip_elementwise_; public: @@ -188,7 +240,7 @@ public: /// Constructor from Params CUTLASS_HOST_DEVICE - LinearCombinationBiasElementwise(Params const ¶ms) { + LinearCombinationBiasElementwise(Params const ¶ms): elementwise_(params.elementwise) { alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta); @@ -290,7 +342,7 @@ public: FragmentC const &frag_C, FragmentCompute const &V) const { - ElementwiseOp elementwise_op; + ElementwiseOpDispatcher elementwise_op(elementwise_); BinaryOp binary_op; FragmentCompute tmp_Accum = NumericArrayConverter()(AB); @@ -322,7 +374,7 @@ public: FragmentAccumulator const &AB, FragmentCompute const &V) const { - ElementwiseOp elementwise_op; + ElementwiseOpDispatcher elementwise_op(elementwise_); BinaryOp binary_op; FragmentCompute tmp_Accum = NumericArrayConverter()(AB); diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index fac1a4af..5e1c847d 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -432,17 +432,12 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X } - // Convert floats back to INT - FragmentAccumulator scaled_accumulator; + // + // Convert float => ElementOutput_ with clamping + // + NumericArrayConverter destination_converter; - NumericArrayConverter compute_converter; - - scaled_accumulator = compute_converter(intermediate); - - // Convert to destination numeric type - NumericArrayConverter destination_converter; - - return destination_converter(scaled_accumulator); + return destination_converter(intermediate); } /// Computes linear scaling: D = alpha * accumulator @@ -466,17 +461,12 @@ public: intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum } - // Convert floats back to INT - FragmentAccumulator scaled_accumulator; + // + // Convert float => ElementOutput_ with clamping + // + NumericArrayConverter destination_converter; - NumericArrayConverter compute_converter; - - scaled_accumulator = compute_converter(intermediate); - - // Convert to destination numeric type - NumericArrayConverter destination_converter; - - return destination_converter(scaled_accumulator); + 
return destination_converter(intermediate); } }; diff --git a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h index a6e2a3dc..ff32f13b 100644 --- a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +++ b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -156,23 +156,24 @@ public: NumericArrayConverter source_converter; NumericArrayConverter accumulator_converter; - ComputeFragment converted_source( + ComputeFragment converted_source{ source_converter(source.real), - source_converter(source.imag)); + source_converter(source.imag)}; - ComputeFragment converted_accumulator( + ComputeFragment converted_accumulator{ accumulator_converter(accumulator.real), - accumulator_converter(accumulator.imag)); - - // Perform binary operations - ComputeFragment intermediate; + accumulator_converter(accumulator.imag)}; multiplies > mul_op; multiply_add > mul_add_op; + // Perform binary operations + // complex multiply: I = beta * C - intermediate.real = mul_op(beta_.real(), converted_source.real); - intermediate.imag = mul_op(beta_.real(), converted_source.imag); + ComputeFragment intermediate { + mul_op(beta_.real(), converted_source.real), + mul_op(beta_.real(), converted_source.imag) + }; intermediate.real = mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); @@ -187,9 +188,9 @@ public: // Convert to destination numeric type NumericArrayConverter destination_converter; - return FragmentOutput( + return FragmentOutput{ destination_converter(intermediate.real), - destination_converter(intermediate.imag)); + destination_converter(intermediate.imag)}; } /// Computes linear scaling: D = alpha * accumulator + beta * source @@ -200,19 +201,19 @@ public: // Convert source to interal compute numeric type NumericArrayConverter accumulator_converter; - ComputeFragment converted_accumulator( + ComputeFragment converted_accumulator{ accumulator_converter(accumulator.real), - accumulator_converter(accumulator.imag)); + accumulator_converter(accumulator.imag)}; // Perform binary operations - ComputeFragment intermediate; - multiplies > mul_op; multiply_add > mul_add_op; // complex multiply-add: I = alpha * AB + I - intermediate.real = mul_op(alpha_.real(), converted_accumulator.real); - intermediate.imag = mul_op(alpha_.real(), converted_accumulator.imag); + ComputeFragment intermediate { + mul_op(alpha_.real(), converted_accumulator.real), + mul_op(alpha_.real(), converted_accumulator.imag) + }; intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); @@ -220,9 +221,9 @@ public: // Convert to destination numeric type NumericArrayConverter destination_converter; - return FragmentOutput( + return FragmentOutput{ destination_converter(intermediate.real), - destination_converter(intermediate.imag)); + destination_converter(intermediate.imag)}; } }; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index 61892fd2..f3119fa4 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -64,11 +64,12 @@ #include "cutlass/transform/pitch_linear_thread_map.h" #include 
"cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h" #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/epilogue_depthwise.h" @@ -89,7 +90,9 @@ template < typename OutputOp_, int ElementsPerAccess, bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultEpilogueSimt { @@ -102,6 +105,8 @@ struct DefaultEpilogueSimt { using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaSimt::LayoutC; using ElementAccumulator = typename WarpMmaSimt::ElementC; + static conv::StrideSupport const kStrideSupport = StrideSupport; + static int const kRank = Rank; // // Thread map @@ -116,13 +121,29 @@ struct DefaultEpilogueSimt { kElementsPerAccess >::Type; - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + static bool const UseCUDAStore = platform::is_same::value; + + using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, ScatterD, - PermuteDLayout + PermuteDLayout, + UseCUDAStore >; + using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore, + kRank + >; + + using OutputTileIterator = typename platform::conditional::type; + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< typename WarpMmaSimt::Shape, typename WarpMmaSimt::ThreadMma, @@ -389,7 +410,7 @@ struct DefaultDirectConvEpilogueSimt { typename WarpMmaSimt::Policy >; - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner< + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear< OutputTileThreadMap, ElementAccumulator >; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index f3b006a1..1692cc30 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -66,6 +66,7 @@ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" @@ -480,7 +481,9 @@ template < typename OutputOp_, int ElementsPerAccess, bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = 
conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultEpilogueTensorOp { @@ -493,6 +496,8 @@ struct DefaultEpilogueTensorOp { using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaTensorOp::LayoutC; using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + static conv::StrideSupport const kStrideSupport = StrideSupport; + static int const kRank = Rank; // // Thread map @@ -508,7 +513,7 @@ struct DefaultEpilogueTensorOp { static bool const UseCUDAStore = platform::is_same::value; - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, ScatterD, @@ -516,6 +521,19 @@ struct DefaultEpilogueTensorOp { UseCUDAStore >; + using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore, + kRank + >; + + using OutputTileIterator = typename platform::conditional::type; + using AccumulatorFragmentIterator = typename platform::conditional::value, cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< typename WarpMmaTensorOp::Shape, diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index 550354b3..d21382b4 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -114,7 +114,61 @@ struct DefaultEpilogueWithBroadcastSimt { typename Base::Padding >; }; +//////////////////////////////////////////////////////////////////////////////// +/// Defines sensible defaults for strided dgrad epilogues for SimtOps. +template < + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess, + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute +> +struct DefaultEpilogueWithBroadcastSimtStridedDgrad { + + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueSimtStridedDgrad< + Shape, + WarpMmaSimt, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast< + Shape, + WarpMmaSimt, + Base::kPartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding + >; +}; //////////////////////////////////////////////////////////////////////////////// /// Defines sensible defaults for epilogues for TensorOps. 
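// A minimal sketch, not part of the patch: the DefaultEpilogueSimt / DefaultEpilogueTensorOp
// changes above choose between the packed PredicatedTileIterator and the new rank-aware
// PredicatedTileIteratorConv based on the StrideSupport template parameter. The simplified,
// self-contained C++ below only illustrates that selection mechanism; the PackedIterator,
// StridedIterator, and DefaultEpilogueSketch names are stand-ins, not CUTLASS APIs.
#include <type_traits>

namespace sketch {

enum class StrideSupport { kUnity, kStrided };

struct PackedIterator  {};   // stand-in for epilogue::threadblock::PredicatedTileIterator
struct StridedIterator {};   // stand-in for epilogue::threadblock::PredicatedTileIteratorConv

template <StrideSupport kStrideSupport>
struct DefaultEpilogueSketch {
  // kUnity keeps the original packed output iterator; kStrided switches to the conv
  // iterator that decomposes NHWC/NDHWC offsets with FastDivmod at runtime.
  using OutputTileIterator = typename std::conditional<
      kStrideSupport == StrideSupport::kUnity,
      PackedIterator,
      StridedIterator>::type;
};

static_assert(std::is_same<DefaultEpilogueSketch<StrideSupport::kStrided>::OutputTileIterator,
                           StridedIterator>::value,
              "strided conv problems select the rank-aware conv iterator");

} // namespace sketch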
diff --git a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h index 5780623e..0f417485 100644 --- a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h +++ b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -70,7 +70,6 @@ struct ConvOutputIteratorParameter { static int const kTensorStrideIdx = (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0); - CUTLASS_HOST_DEVICE static OutputIteratorLayout layout(const TensorRef & ref) { return ref.stride(kTensorStrideIdx); @@ -80,10 +79,59 @@ struct ConvOutputIteratorParameter { static OutputTensorCoord extent(ConvProblemSize problem_size) { return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); } - }; +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + using TensorLayout = layout::TensorNHWC; + using OutputIteratorLayout = layout::TensorNHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! 
Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNDHWC; + using OutputIteratorLayout = layout::TensorNDHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; template < int InterleavedK, diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index e3330de8..14a85447 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -51,6 +51,8 @@ #include "cutlass/arch/arch.h" #include "cutlass/arch/memory.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" //////////////////////////////////////////////////////////////////////////////// @@ -102,10 +104,10 @@ public: /// Fragment object using Fragment = Array< - Element, - ThreadMap::Iterations::kColumn * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; /// Memory access size @@ -123,13 +125,27 @@ public: Params() { } CUTLASS_HOST_DEVICE - Params(Layout const &layout): + Params(Layout const &layout): PredicatedTileIteratorParams( layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc() ) { } + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. + conv::Conv2dProblemSize const &problem_size): + Params(layout) + { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. 
+ conv::Conv3dProblemSize const &problem_size): + Params(layout) + { } + CUTLASS_HOST_DEVICE Params(Base const &base) : Base(base) { } @@ -202,7 +218,7 @@ private: int state_[3]; /// Scatter indices - int const *indices_; + int const *indices_; /// PermuteDLayout PermuteDLayout permute_layout_; @@ -253,7 +269,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = ((thread_offset.column() + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); } @@ -267,8 +283,8 @@ public: } // Initialize byte_pointer_ - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; if (ScatterD) { @@ -306,7 +322,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); int row_offset = row * ThreadMap::Delta::kRow @@ -330,7 +346,7 @@ public: bool guard = row_guard && mask_.predicates[column]; cutlass::arch::global_load< - AccessType, + AccessType, sizeof(AccessType) >( frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + @@ -380,11 +396,11 @@ public: CUTLASS_PRAGMA_UNROLL for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; bool row_guard = ((row_offset + thread_start_row_) < extent_row_); @@ -426,7 +442,7 @@ public: (void *)&memory_pointer[0], guard); } - + if (!PermuteD) { memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); } @@ -649,7 +665,7 @@ public: } thread_start_row_ += ThreadMap::Shape::kRow; - + if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; @@ -663,7 +679,7 @@ public: store_byte_pointer_ += params_.advance_group; } - thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; if (state_[1] == ThreadMap::Count::kGroup) { @@ -679,7 +695,7 @@ public: store_byte_pointer_ += params_.advance_cluster; } - thread_start_row_ += ThreadMap::Count::kGroup * + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { @@ -1121,6 +1137,14 @@ public: initialize(layout.stride()); } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. 
+ conv::Conv2dProblemSize const &problem_size): + Params(layout) + { } + }; /// Mask object diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h new file mode 100644 index 00000000..c3c722bc --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIteratorConv | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not + bool UseCUDAStore = false, + int Rank = 4 +> +class PredicatedTileIteratorConv { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + static int const kRank = Rank; + using Layout = typename platform::conditional::type; + + using Stride = typename Layout::Stride; + static int const kStrideRank = Layout::kStrideRank; + + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using MappedLayout = layout::RowMajor; + using Index = typename MappedLayout::Index; + using LongIndex = typename MappedLayout::LongIndex; + using TensorCoord = typename MappedLayout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod[kStrideRank - 1]; + Stride tensor_stride; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, conv::Conv2dProblemSize const &problem_size): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, + 
make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(problem_size.Q); + divmod[1] = FastDivmod(problem_size.P); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, conv::Conv3dProblemSize const &problem_size): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(problem_size.Q); + divmod[1] = FastDivmod(problem_size.P); + divmod[2] = FastDivmod(problem_size.Z); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorConv( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + // Initialize byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + // Initialize internal state counter + state_[0] = state_[1] = 
state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + Stride tensor_coord = CoordinateDecompositionLittleEndian(row_offset + thread_start_row_, params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess + tensor_offset / kElementsPerAccess], + guard); + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + Stride tensor_coord = CoordinateDecompositionLittleEndian((row_offset + thread_start_row_), params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[tensor_offset / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[tensor_offset / kElementsPerAccess], + guard); + } + + memory_pointer += (ThreadMap::Delta::kColumn / 
kElementsPerAccess); + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start column from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator++() { + + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow + * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * + increment_row; + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + thread_start_row_ += + ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * + ThreadMap::Shape::kRow * + increment_group; + + // Tile + thread_start_row_ += + ThreadMap::Shape::kGroup * + ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * + ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Gets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git
a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index 2b412bf1..5e9aa22b 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -32,16 +32,6 @@ \brief */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" @@ -257,8 +247,6 @@ struct PredicatedTileIteratorParams { } }; - - /////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h similarity index 98% rename from include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h rename to include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h index 79a91f75..5af6997e 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h @@ -66,7 +66,7 @@ namespace threadblock { template ::value / 8> -class SharedLoadIteratorPitchLiner { +class SharedLoadIteratorPitchLinear { public: using ThreadMap = ThreadMap_; using Element = Element_; @@ -123,7 +123,7 @@ class SharedLoadIteratorPitchLiner { /// Constructor CUTLASS_DEVICE - SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx) + SharedLoadIteratorPitchLinear(TensorRef ref, int thread_idx) : byte_pointer_(reinterpret_cast(ref.data())), stride_((ref.stride(0) * sizeof_bits::value) / 8), base_smem_address_(0) { diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 3e842d6a..84fb06de 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -28,15 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 0d925f26..a2d062a0 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -33,15 +33,6 @@ \brief Defines a class for using IEEE half-precision floating-point types in host or device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
-*/ #pragma once diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 27c8de2d..964d2ff3 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -33,20 +33,13 @@ This is inspired by the Standard Library's header. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" +#include + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) @@ -216,6 +209,35 @@ struct magnitude_squared_difference { } }; +// Computes the reciprocal square root +template +struct inverse_square_root; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs) const { +#if defined(__CUDA_ARCH__) + return rsqrtf(lhs); +#else + return 1.f / std::sqrt(lhs); +#endif + } +}; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &lhs) const { +#if defined(__CUDA_ARCH__) + auto result = hrsqrt(reinterpret_cast<__half const &>(lhs)); + return reinterpret_cast(result); +#else + return half_t(1.f / std::sqrt(half_t::convert(lhs))); +#endif + } +}; + /// Divides template struct divides { @@ -546,8 +568,6 @@ struct bit_xor { } }; - - ////////////////////////////////////////////////////////////////////////////////////////////////// /// Atomic reductions diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 29ee9605..4613f7bf 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -1093,7 +1093,7 @@ private: else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = thread_mma.partition_A(sS); - Tensor tCrS = make_fragment_like(thread_mma.partition_fragment_A(sS(_,_,Int<0>{}))); + Tensor tCrS = make_tensor(thread_mma.partition_fragment_A(sS(_,_,Int<0>{})).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -1101,7 +1101,7 @@ private: else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = thread_mma.partition_A(sZ); - Tensor tCrZ = make_fragment_like(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{}))); + Tensor tCrZ = make_tensor(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 750b16a6..8d045d6e 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -96,7 +96,7 @@ public: using ElementB = typename GemmKernel::ElementB; using ElementC = typename GemmKernel::ElementC; using ElementD = typename 
GemmKernel::ElementD; - using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; using DispatchPolicy = typename GemmKernel::DispatchPolicy; using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; @@ -361,9 +361,13 @@ public: CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - launch_result = cuda_adapter->launch( - grid, cluster, block, smem_size, stream, kernel_params, 0 - ); + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); } else { return Status::kErrorInternal; diff --git a/include/cutlass/gemm/gemm_enumerated_types.h b/include/cutlass/gemm/gemm_enumerated_types.h index efc93e55..66aae898 100644 --- a/include/cutlass/gemm/gemm_enumerated_types.h +++ b/include/cutlass/gemm/gemm_enumerated_types.h @@ -32,16 +32,6 @@ \brief Defines common types used for all GEMM-like operators. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index f3f781a6..08b30c74 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -177,8 +177,8 @@ public: int const *ptr_scatter_D_indices = nullptr) : UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), @@ -486,18 +486,18 @@ public: int offset_k = 0; int problem_size_k = params.problem_size.k(); - ElementA *ptr_A = static_cast(params.ptr_A); + ElementA *ptr_A = static_cast(params.ptr_A); ElementB *ptr_B = static_cast(params.ptr_B); // // Fetch pointers based on mode. 
// - if (params.mode == GemmUniversalMode::kGemm || + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; @@ -566,10 +566,10 @@ public: // Compute threadblock-scoped matrix multiply-add mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, accumulators); // @@ -592,13 +592,13 @@ public: int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_C = static_cast(params.ptr_C); ElementC *ptr_D = static_cast(params.ptr_D); // // Fetch pointers based on mode. // - + // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); @@ -606,7 +606,7 @@ public: // If performing a reduction via split-K, fetch the initial synchronization if (params.grid_tiled_shape.k() > 1) { - + // Fetch the synchronization lock initially but do not block. semaphore.fetch(); @@ -647,14 +647,14 @@ public: ); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, + shared_storage.epilogue, + thread_idx, + warp_idx, lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C = iterator_D; @@ -666,11 +666,11 @@ public: // Execute the epilogue operator to update the destination tensor. 
epilogue( - output_op, - iterator_D, - accumulators, - iterator_C); - + output_op, + iterator_D, + accumulators, + iterator_C); + // // Release the semaphore // @@ -687,7 +687,7 @@ public: // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } - + semaphore.release(lock); } } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 9229feee..877d2c1d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -69,7 +69,6 @@ public: using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 1da5b6d3..abf79e84 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -70,7 +70,6 @@ public: using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 0cc3ffc8..1630583f 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -35,16 +35,6 @@ \brief Parameters structures for persistent tile schedulers */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. 
-*/ - #include "cutlass/coord.h" #include "cutlass/kernel_hardware_info.h" #include "cutlass/workspace.h" diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 35cf3fb7..5e217898 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -147,9 +147,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; mma( tmp, @@ -157,7 +155,7 @@ struct Mma_HFMA2 < ptr_B[n*Shape::kK + k], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -239,9 +237,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; Array tmp_B; tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); @@ -253,7 +249,7 @@ struct Mma_HFMA2< tmp_B, tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -335,10 +331,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) { - Array tmp; - Array *ptr_tmp = &tmp; - - ptr_tmp[0] = ptr_D[m + n * Shape::kM/2]; + Array tmp { ptr_D[m + n * Shape::kM/2] }; mma( tmp, @@ -346,7 +339,7 @@ struct Mma_HFMA2 < ptr_B[k * Shape::kN + n], tmp); - ptr_D[m + n * Shape::kM/2] = ptr_tmp[0]; + ptr_D[m + n * Shape::kM/2] = tmp; } } } @@ -428,9 +421,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; mma( tmp, @@ -438,7 +429,7 @@ struct Mma_HFMA2< ptr_B[k*Shape::kN/2 + n], tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -521,9 +512,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; Array tmp_A; tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); @@ -535,7 +524,7 @@ struct Mma_HFMA2 < ptr_B[n*Shape::kK + k], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -617,9 +606,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; Array tmp_B; tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); @@ -631,7 +618,7 @@ struct Mma_HFMA2 < tmp_B, tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -713,9 +700,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; Array tmp_A; tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); @@ -727,7 +712,7 @@ struct Mma_HFMA2 < ptr_B[k*Shape::kN + n], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -810,9 +795,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; mma( tmp, @@ -820,7 +803,7 @@ struct Mma_HFMA2< ptr_B[k*Shape::kN/2 + n], tmp); - 
ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 596fe5a4..b79e587d 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -32,16 +32,6 @@ \brief Implements streamk threadblock mapping blockIdx to GEMM problems. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/half.h b/include/cutlass/half.h index e22c8be3..c203e6cb 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -34,16 +34,6 @@ device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #ifndef CUTLASS_ENABLE_F16C diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index e06a0491..1a9728e7 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -34,16 +34,6 @@ device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index b69399ff..62dcb8b4 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -30,16 +30,6 @@ **************************************************************************************************/ #pragma once -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #if !defined(__CUDACC_RTC__) #include "cuda_runtime.h" diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index 8ece4378..32aa17a5 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -38,16 +38,6 @@ defined in cutlass/tensor_ref.h. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. 
-*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/layout/pitch_linear.h b/include/cutlass/layout/pitch_linear.h index 4063a988..8c9540f4 100644 --- a/include/cutlass/layout/pitch_linear.h +++ b/include/cutlass/layout/pitch_linear.h @@ -32,16 +32,6 @@ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 4cef8333..2a3a0954 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -33,16 +33,6 @@ \brief Boost-like numeric conversion operator for CUTLASS numeric types */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if !defined(__CUDACC_RTC__) @@ -848,18 +838,21 @@ struct NumericArrayConverter result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + Array result; reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast(source)); + return result; #else NumericConverter convert_; + // NOTE: cutlass::Array is NOT an aggregate type and + // below `{}` does NOT conduct zero initialization. Below `{}` will + // conduct default initialization (calling default ctr). We use this syntax + // to resolve compiler warning on uninitialized member variable. + Array result{}; result[0] = convert_(source[0]); result[1] = convert_(source[1]); + return result; #endif - - return result; } CUTLASS_HOST_DEVICE @@ -879,17 +872,19 @@ struct NumericArrayConverter { CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - reinterpret_cast(result) = __half22float2(reinterpret_cast<__half2 const &>(source)); + float2 result2 = __half22float2(reinterpret_cast<__half2 const &>(source)); + return { + float{result2.x}, + float{result2.y} + }; #else NumericConverter convert_; - result[0] = convert_(source[0]); - result[1] = convert_(source[1]); + return { + convert_(source[0]), + convert_(source[1]) + }; #endif - - return result; } CUTLASS_HOST_DEVICE @@ -1482,7 +1477,7 @@ struct NumericArrayConverterPacked4Element { for (int i = 0; i < 4; ++i) { if (platform::is_same::value) { result[i] = convert_(s[i]); - } + } else { // conjugate result[i] = conj(convert_(s[i])); } @@ -2306,14 +2301,69 @@ template < struct NumericArrayConverter : public PackedNumericArrayConverter {}; - - ///////////////////////////////////////////////////////////////////////////////////////////////// - /// Partial specialization for Array <= Array /// Conversion is performed with saturation regardless of setting of /// the `Round` template parameter. 
+template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + // Convert to int to int8_t + NumericConverter destination_converter; + result_type result; + result[0] = destination_converter(source[0]); + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +// To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. +template < + typename T, + int N, + FloatRoundStyle Round +> +struct NumericArrayFP32ToIntConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + static_assert(platform::numeric_limits::is_integer, "the dest type has to be int."); + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + // Convert float to int + Array temporary; + + NumericArrayConverter compute_converter; + temporary = compute_converter(source); + + // Convert to int to int8_t + NumericArrayConverter destination_converter; + return destination_converter(temporary); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + + template < int N, FloatRoundStyle Round @@ -2322,19 +2372,74 @@ struct NumericArrayConverter { using result_type = Array; using source_type = Array; - static FloatRoundStyle const round_style = Round; CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - // Convert float to int - Array temporary; + NumericArrayFP32ToIntConverter converter; + return converter(source); + } - NumericArrayConverter compute_converter; - temporary = compute_converter(source); + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; - // Convert to int to int8_t - NumericArrayConverter destination_converter; - return destination_converter(temporary); +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); } CUTLASS_HOST_DEVICE @@ -2508,7 +2613,7 @@ namespace detail { template CUTLASS_DEVICE static void convert_helper( - typename ArrayConverter::result_type& result, + typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { using ElementRes = typename ArrayConverter::result_type::Element; @@ -2530,14 
+2635,14 @@ namespace detail { static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); - static_assert(cutlass::platform::is_same::value, + static_assert(cutlass::platform::is_same::value, "ResultVectorArray must have the same type ArrayConverter::result_type"); - static_assert(cutlass::platform::is_same::value, + static_assert(cutlass::platform::is_same::value, "SourceVectorArray must have the same type ArrayConverter::result_type"); static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); - + constexpr int vector_width = ResultVectorArray::kElements; static_assert(ispow2(vector_width), "Vector width must be a power of 2"); @@ -2569,8 +2674,8 @@ namespace detail { public: /* - A method to convert vectors of elements using the packed_convert method of the converter. - + A method to convert vectors of elements using the packed_convert method of the converter. + Converters using this class must implement packed convert and support 1 or more vector conversions. */ template @@ -2651,7 +2756,7 @@ private: uint32_t final_prmt_idx = final_prmt_base | sign; // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value - // as the index to prmt. + // as the index to prmt. // It first select both the positive and negative candidates, then uses the sign bit to // select the correct candidate. asm volatile( @@ -2675,8 +2780,8 @@ public: static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; @@ -2684,7 +2789,7 @@ public: CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2771,7 +2876,7 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 1, 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements PackedResultType r; @@ -2788,23 +2893,23 @@ private: return r; } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2844,7 +2949,7 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + PackedResultType r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -2875,15 +2980,15 @@ public: result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2923,12 +3028,12 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + PackedResultType r; // View the input as reg uint32_t src_reg = to_reg(source); - // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores // the result in r (without introducing extra cvt.u32.u8 instruction) uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; uint32_t* result_as_int = reinterpret_cast(&r); @@ -2948,15 +3053,15 @@ public: static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3010,10 +3115,10 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -3034,7 +3139,7 @@ private: " prmt.b32 %0, %1, %2, %3;\n" "}\n" : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); } // The below XOR does the following: @@ -3057,7 +3162,7 @@ private: " lop3.b32 %0, %0, %1, %2, %3;\n" "}\n" : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); } // We will issue 2 hfmas that do the following: @@ -3087,23 +3192,23 @@ private: return reinterpret_cast(r); } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3145,10 +3250,10 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) auto result = reinterpret_cast(r); @@ -3176,18 +3281,18 @@ private: // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve // the same result as add.s16x2 instruction. // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) - // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to + // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to // three predefined constant values as follows: // ta = 0xF0; // tb = 0xCC; // tc = 0xAA; // kImmLut = F(ta, tb, tc); - // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA - static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; + // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA + static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; for (int ii = 0; ii < RegArray::kElements; ++ii) { // The bit-wise operation executed below is `r[ii] = (r[ii] & 0x03FF03FF) ^ 0x66006600;` - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(r[ii]) : "r"(r[ii]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); } @@ -3209,14 +3314,14 @@ public: result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3256,11 +3361,11 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; - + RegArray r; + // View the input as reg uint32_t src_reg = to_reg(source); uint32_t const prmt_indices[2] = {0x5150, 0x5352}; @@ -3289,15 +3394,15 @@ public: result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3352,10 +3457,10 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -3371,7 +3476,7 @@ private: " prmt.b32 %0, %1, %2, %3;\n" "}\n" : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); } // The below XOR does the following: @@ -3390,7 +3495,7 @@ private: " lop3.b32 %0, %0, %1, %2, %3;\n" "}\n" : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); } // We will issue 2 bfmas that do the following: @@ -3400,7 +3505,7 @@ private: // This is the BF16 {136, 136} represented as an integer. static constexpr uint32_t bias_rep = 0x43084308; const __nv_bfloat162& bias = reinterpret_cast(bias_rep); - + CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); @@ -3410,23 +3515,23 @@ private: return reinterpret_cast(r); } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3466,7 +3571,7 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + NumericArrayConverter convert_int8_to_f32; Array tmp = convert_int8_to_f32(source); NumericArrayConverter convert_f32_to_bf16; @@ -3481,15 +3586,15 @@ public: result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3529,7 +3634,7 @@ private: (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + NumericArrayConverter convert_uint8_to_f32; Array tmp = convert_uint8_to_f32(source); NumericArrayConverter convert_f32_to_bf16_; @@ -3543,15 +3648,15 @@ public: static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return 
convert(s); } }; @@ -3705,7 +3810,7 @@ struct PackPredicates { int word_idx = (i / kWordSize); int bit_idx = (i % kWordSize); - uint8_t mask = ((predicates[i] ? 1u : 0u) << bit_idx); + uint8_t mask = static_cast((predicates[i] ? 1u : 0u) << bit_idx); bytes[word_idx] = (bytes[word_idx] | mask); } return packed; diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 46f343aa..42bc418a 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -33,16 +33,6 @@ \brief Top-level include for all CUTLASS numeric types. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index cb7c2087..5519fbe7 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -32,15 +32,6 @@ \file \brief Top-level include for all CUTLASS numeric types. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index b7d9e093..2ab7ae04 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -264,7 +264,6 @@ public : dim3 block_id = cute::block_id_in_cluster(); auto cluster_size = cute::size(cluster_shape); static constexpr int MaxClusterSize = 16; - static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 927f80cb..ba74ae72 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -95,16 +95,6 @@ * counterparts (or trivially find-and-replace their occurrences in code text). */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
-*/ - //----------------------------------------------------------------------------- // Dependencies //----------------------------------------------------------------------------- @@ -159,7 +149,7 @@ /// builtin_unreachable #if !defined(CUTLASS_GCC_UNREACHABLE) -# if defined(__clang__) || defined(__GNUC__) +# if defined(__GNUC__) # define CUTLASS_GCC_UNREACHABLE __builtin_unreachable() # else # define CUTLASS_GCC_UNREACHABLE @@ -950,7 +940,6 @@ struct numeric_limits { static constexpr bool is_integer = true; }; -#if !defined(__CUDACC_RTC__) template <> struct numeric_limits { CUTLASS_HOST_DEVICE @@ -958,7 +947,6 @@ struct numeric_limits { static constexpr bool is_integer = false; static constexpr bool has_infinity = true; }; -#endif /// std::float_round_style using CUTLASS_STL_NAMESPACE::float_round_style; diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h index 22699a28..c67af387 100755 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -32,16 +32,6 @@ \brief */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index d1e6e646..0a41e95c 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -32,15 +32,6 @@ \file \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/workspace.h b/include/cutlass/workspace.h index 82ff77c1..6dc0141c 100644 --- a/include/cutlass/workspace.h +++ b/include/cutlass/workspace.h @@ -32,16 +32,6 @@ \brief Utilities for initializing workspaces */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if !defined(__CUDACC_RTC__) diff --git a/media/docs/build/building_with_clang_as_host_compiler.md b/media/docs/build/building_with_clang_as_host_compiler.md index d14eb20e..54b2c78e 100644 --- a/media/docs/build/building_with_clang_as_host_compiler.md +++ b/media/docs/build/building_with_clang_as_host_compiler.md @@ -36,17 +36,17 @@ is the following error when attempting to use clang: ## Required CMake options The Clang build requires specifying the following CMake options. 
-Replace `<path-to-clang++>` with the path to your `clang++` executable, -and replace `<path-to-clang>` with the path to your `clang` executable -(which must have the same version as your `clang++` executable). -You may use `clang++` resp. `clang` directly if they are in your `PATH`. +Replace `<path-to-clang++>` with the path to your `clang++` executable. +You may use `clang++` directly if it is in your `PATH`. * `CMAKE_CXX_COMPILER=<path-to-clang++>` * `CMAKE_CUDA_HOST_COMPILER=<path-to-clang++>` -* `CMAKE_C_COMPILER=<path-to-clang>` -Please note that both `CMAKE_CXX_COMPILER` and `CMAKE_C_COMPILER` -must be set, even though CUTLASS is a C++ project, not a C project. +One must set both! It's not enough just to set the `CXX` environment +variable, for example. Symptoms of only setting `CMAKE_CXX_COMPILER` +(or only setting the `CXX` environment variable) include `cc1plus` +(GCC's compiler executable) reporting build errors due to it not +understanding Clang's command-line options. Users can also specify a particular CUDA Toolkit version by setting the CMake option `CMAKE_CUDA_COMPILER` diff --git a/media/docs/cute/02_layout_algebra.md b/media/docs/cute/02_layout_algebra.md index f9c9d2fb..3b70252b 100644 --- a/media/docs/cute/02_layout_algebra.md +++ b/media/docs/cute/02_layout_algebra.md @@ -317,23 +317,25 @@ The `complement` of a layout attempts to find another layout that represents the You can find many examples and checked post-conditions in [the `complement` unit test](../../../test/unit/cute/core/complement.cpp). The post-conditions include ```cpp -// @post cosize(make_layout(@a layout_a, @a result))) >= @a cosize_hi -// @post cosize(@a result) >= round_up(@a cosize_hi, cosize(@a layout_a)) +// @post cosize(make_layout(@a layout_a, @a result))) >= size(@a cotarget) +// @post cosize(@a result) >= round_up(size(@a cotarget), cosize(@a layout_a)) // @post for all i, 1 <= i < size(@a result), // @a result(i-1) < @a result(i) // @post for all i, 1 <= i < size(@a result), // for all j, 0 <= j < size(@a layout_a), // @a result(i) != @a layout_a(j) -Layout complement(LayoutA const& layout_a, Integral const& cosize_hi) +Layout complement(LayoutA const& layout_a, Shape const& cotarget) ``` -That is, the complement `R` of a layout `A` with respect to an integer `M` satisfies the following properties. +That is, the complement `R` of a layout `A` with respect to a Shape (IntTuple) `M` satisfies the following properties. 1. The size (and cosize) of `R` is *bounded* by `size(M)`. 2. `R` is *ordered*. That is, the strides of `R` are positive and increasing. This means that `R` is unique. 3. `A` and `R` have *disjoint* codomains. `R` attempts to "complete" the codomain of `A`. +The `cotarget` parameter above is most commonly an integer -- you can see we only use `size(cotarget)` above. However, sometimes it is useful to specify an integer that has static properties. For example, `28` is a dynamic integer and `(_4,7)` is a shape with size `28` that is statically known to be divisible by `_4`. Both will produce the same `complement` mathematically, but the extra information can be used by `complement` to preserve the staticness of the result as much as possible. + ### Complement Examples -`complement` is most effective on static shapes and strides, so consider all integers below to be static. Similar examples for dynamic shapes and strides can be found in the unit test. +`complement` is most effective on static shapes and strides, so consider all integers below to be static.
Similar examples for dynamic shapes and strides as well as IntTuple `cotarget` can be found in [the unit test](../../../test/unit/cute/core/complement.cpp). * `complement(4:1, 24)` is `6:4`. Note that `(4,6):(1,4)` has cosize `24`. The layout `4:1` is effectively repeated 6 times with `6:4`. @@ -425,9 +427,9 @@ Layout Shape : (M, N, L, ...) Tiler Shape : logical_divide : ((TileM,RestM), (TileN,RestN), L, ...) -zipped_divide : ((TileM,TileN,...), (RestM,RestN,L,...)) -tiled_divide : ((TileM,TileN,...), RestM, RestN, L, ...) -flat_divide : (TileM, TileN, ..., RestM, RestN, L, ...) +zipped_divide : ((TileM,TileN), (RestM,RestN,L,...)) +tiled_divide : ((TileM,TileN), RestM, RestN, L, ...) +flat_divide : (TileM, TileN, RestM, RestN, L, ...) ``` For example, the `zipped_divide` function applies `logical_divide`, and then gathers the "subtiles" into a single mode and the "rest" into a single mode. diff --git a/media/docs/fundamental_types.md b/media/docs/fundamental_types.md index 8e17ad57..8bef0702 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/fundamental_types.md @@ -63,13 +63,12 @@ template < typename T, // element type int N // number of elements > -class Array; +struct Array; ``` `Array` defines a statically sized array of elements of type _T_ and size _N_. This class is similar to -[`std::array<>`](https://en.cppreference.com/w/cpp/container/array) in the Standard Library with two notable exceptions: -* constructors for each element may not be called -* partial specializations exist to pack or unpack elements smaller than one byte. +[`std::array<>`](https://en.cppreference.com/w/cpp/container/array) in the Standard Library with one notable exception: +partial specializations exist to pack or unpack elements smaller than one byte. `Array<>` is intended to be a convenient and uniform container class to store arrays of numeric elements regardless of data type or vector length. The storage needed is expected to be the minimum necessary given the logical size of each numeric type in bits (numeric types smaller than one byte are densely packed). Nevertheless, the size reported by `sizeof(Array)` is always an integer multiple of bytes. 
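A minimal host-side sketch of the `Array<>` packing guarantee described above; the specific byte counts here are assumptions that follow from dense sub-byte packing rounded up to whole bytes, not values taken from this patch:

```cpp
// Sketch: sizeof(Array<T, N>) is always a whole number of bytes, with
// sub-byte element types packed densely (assumed figures, per the text above).
#include <cstdio>

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

int main() {
  // Ordinary element types: storage is simply N elements of T.
  static_assert(sizeof(cutlass::Array<cutlass::half_t, 4>) == 8,
                "4 x 16-bit elements occupy 8 bytes");

  // Sub-byte element types are densely packed, yet sizeof() still reports
  // whole bytes: 8 x 4-bit elements are expected to occupy 4 bytes.
  std::printf("sizeof(Array<int4b_t, 8>) = %zu bytes\n",
              sizeof(cutlass::Array<cutlass::int4b_t, 8>));
  return 0;
}
```

Printing rather than asserting the sub-byte case keeps the sketch honest if a particular specialization chooses a wider storage word.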
diff --git a/media/docs/profiler.md b/media/docs/profiler.md index d318e3d0..34282925 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -210,7 +210,6 @@ GEMM [int] --inst_k,--instruction-shape::k Math instruction shape in the K dimension [int] --min_cc,--minimum-compute-capability Minimum device compute capability [int] --max_cc,--maximum-compute-capability Maximum device compute capability - Examples: Profile a particular problem size: diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 55f4e2f1..62ac6c27 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -1654,7 +1654,7 @@ class GemmOperationBase: extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a=DataTypeNames[self.A.element], element_b=DataTypeNames[self.B.element], - element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc=DataTypeNames[self.accumulator_type()], element_c=DataTypeNames[self.C.element], element_d=DataTypeNames[self.epilogue_functor.element_output], core_name=self.core_name()) diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py index ef724c04..87025eea 100644 --- a/python/cutlass/emit/common.py +++ b/python/cutlass/emit/common.py @@ -118,16 +118,18 @@ cutlass::Status ${name}_kernel_run( typename DeviceKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K, L}, // problem size - A, // ptrA - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A - B, // ptrB - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B { + A, // ptrA + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + }, + { + {alpha, beta}, C, // ptrC cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C D, // ptrD cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D - {alpha, beta}, }, hw_info }; diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 613311ac..ac13e866 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -232,7 +232,7 @@ _PYTORCH_GEMM_INCLUDES = { #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/util/packed_stride.hpp" """, } @@ -583,7 +583,11 @@ setup( '${name}_kernel.cu', ], include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], - extra_compile_args=['-std=c++17'] + extra_compile_args={ + 'cxx': ['-std=c++17'], + 'nvcc': ['-std=c++17', ${extra_compile_args}], + }, + libraries=['cuda'] ), ], cmdclass={ @@ -593,7 +597,7 @@ setup( """ -def _generate_setup(name: str, sourcedir: str): +def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""): """ Generates a setup.py file for the extension @@ -601,10 +605,12 @@ def _generate_setup(name: str, sourcedir: str): :type name: str :param sourcedir: directory to which generated source files should be written :type sourcedir: str + :param extra_compile_args: additional arguments to pass to setup.py + :type extra_args: str """ 
setup_py_file = os.path.join(sourcedir, "setup.py") setup_source = SubstituteTemplate( - _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH} + _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args} ) with open(setup_py_file, "w") as outfile: outfile.write(setup_source) @@ -696,6 +702,7 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): os.path.join(CUTLASS_PATH, "include"), os.path.join(CUTLASS_PATH, "tools/util/include"), ], + extra_ldflags=["-lcuda"], verbose=(logger.level == logging.DEBUG) ) return jitmodule @@ -759,7 +766,10 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = "" with open(cpp_file, "w") as outfile: outfile.write(cpp_source) - _generate_setup(name, sourcedir) + extra_compile_args = "" + if cc == 90: + extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'" + _generate_setup(name, sourcedir, extra_compile_args) if jit: return _jit(name, cc, cpp_file, cuda_file) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 3ef021f7..7c16cc68 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -137,9 +137,9 @@ class KernelsForDataType: # Finally, go through all available alignment combinations and find # one for which all values are less than those passed in. key = None - alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) + alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) for align_A, align_B, align_C in alignments: - if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C: + if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0: key = f"{align_A} {align_B} {align_C}" break diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index e486a431..e74c4078 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -712,4 +712,4 @@ class Gemm(OperationBase): if sync: arguments.sync() - return arguments \ No newline at end of file + return arguments diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 9850b68a..f739c15a 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -205,7 +205,7 @@ class GemmOperation: extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a = DataTypeNames[self.A.element], element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc = DataTypeNames[self.accumulator_type()], element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element], core_name = self.core_name()) @@ -216,7 +216,7 @@ class GemmOperation: datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a = DataTypeNames[self.A.element], element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc = DataTypeNames[self.accumulator_type()], element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element]) return datatype_name @@ -744,7 +744,7 @@ using ${operation_name}_mainloop = cute::Shape, cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, ${stages}, - ${kernel_schedule} + 
${kernel_schedule} >::CollectiveOp; // Gemm operator ${operation_name} @@ -817,8 +817,9 @@ ${compile_guard_end} else: epilogue_functor = self.epilogue_functor.emit_declaration() # - element_a = DataTypeTag[operation.A.element] - element_b = DataTypeTag[operation.B.element] + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] values = { 'operation_name': operation.procedural_name(), diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 9f1045f3..0ac604e7 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -967,6 +967,7 @@ class ConvOperation3x: def configuration_name(self): prefix = 'cutlass3x' + arch = self.arch opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] tbm = self.tile_description.tile_shape[0] tbn = self.tile_description.tile_shape[1] @@ -979,7 +980,7 @@ class ConvOperation3x: kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule] epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule] - return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}" + return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}" def procedural_name(self): return self.configuration_name() diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index a347091c..710cad31 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -250,6 +250,12 @@ ComplexTransformTag = { ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', } +# Used for cutlass3x complex kernel collective mainloop builder instantiation +ComplexTransformTag3x = { + ComplexTransform.none: 'cute::identity', + ComplexTransform.conj: 'cute::conjugate', +} + # RealComplexBijection = [ (DataType.f16, DataType.cf16), diff --git a/test/python/cutlass/emit/pytorch.py b/test/python/cutlass/emit/pytorch.py index ac75dbb5..18388a76 100644 --- a/test/python/cutlass/emit/pytorch.py +++ b/test/python/cutlass/emit/pytorch.py @@ -124,7 +124,6 @@ class PyTorchExtensionTest(unittest.TestCase): dtype = torch.float16 plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor) - plan.activation = cutlass.epilogue.relu op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: @@ -132,7 +131,7 @@ class PyTorchExtensionTest(unittest.TestCase): A, B, C, _ = _initialize(dtype, 1024, 256, 512) - D_ref = torch.nn.functional.relu(A @ B) + D_ref = A @ B D = mod.run(A, B) assert torch.allclose(D, D_ref) @@ -147,7 +146,7 @@ class PyTorchExtensionTest(unittest.TestCase): alpha = 2.0 beta = -1.0 - D_ref = torch.nn.functional.relu((A @ B) * alpha + (beta * C)) + D_ref = (A @ B) * alpha + (beta * C) D = mod.run(A, B, 
C, alpha, beta) assert torch.allclose(D, D_ref) diff --git a/test/unit/conv/cache_testbed_output.h b/test/unit/conv/cache_testbed_output.h index 8c443022..4f3981e8 100644 --- a/test/unit/conv/cache_testbed_output.h +++ b/test/unit/conv/cache_testbed_output.h @@ -122,19 +122,15 @@ inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) struct CachedTestResult { uint32_t D; - uint32_t sum; - uint32_t sum_of_square; - uint32_t second_sum_of_square; // // Methods // - CachedTestResult(): D(), sum(), sum_of_square(), second_sum_of_square() { } + CachedTestResult(): D() + { } - CachedTestResult(uint32_t D): D(D), sum(), sum_of_square(), second_sum_of_square() { } - - CachedTestResult(uint32_t D, uint32_t sum, uint32_t sum_of_square, uint32_t second_sum_of_square): - D(D), sum(sum), sum_of_square(sum_of_square), second_sum_of_square(second_sum_of_square) { } + CachedTestResult(uint32_t D): D(D) + { } operator bool() const { return bool(D); @@ -262,6 +258,7 @@ inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { case cutlass::conv::Operator::kFprop: return "fprop"; case cutlass::conv::Operator::kDgrad: return "dgrad"; case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; } return "conv_unknown"; } diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index 0ac101ec..d3a6782f 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -140,14 +140,19 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu + deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu conv2d_fprop_with_broadcast_simt_sm80.cu + deconv2d_with_broadcast_simt_sm80.cu conv3d_fprop_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_dgrad_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_wgrad_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu + deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_fprop_with_broadcast_simt_sm80.cu + deconv3d_with_broadcast_simt_sm80.cu + ) endif() @@ -176,6 +181,7 @@ cutlass_test_unit_add_executable( conv2d_fprop_with_broadcast_sm75.cu conv2d_fprop_with_reduction_sm75.cu + conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu ) @@ -209,6 +215,7 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu # Conv3d + conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu # Group Conv2d diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index 240c579a..c75ebbee 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -85,7 +85,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens } //////////////////////////////////////////////////////////////////////////////// -#if 0 + 
TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_64x3_64x64x64) { @@ -116,7 +116,8 @@ TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_t cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -124,7 +125,6 @@ TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_t /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d()); } -#endif //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu index 4d0cb2f2..944af8b6 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu @@ -81,7 +81,8 @@ TEST(SM80_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f32nhwc_f32nh cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -103,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -void TestResidaulBlock() { +static void Conv2dFpropSM80TestResidaulBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -161,7 +162,7 @@ void TestResidaulBlock() { TEST(SM80_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - TestResidaulBlock(); + Conv2dFpropSM80TestResidaulBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 2ace470b..d957beb0 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -153,7 +153,6 @@ public: else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } else if (dist_kind == cutlass::Distribution::Gaussian) { @@ -489,7 +488,8 @@ public: fname << "error_Conv2d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << ss_problem_size_text.str() << Conv2d::ThreadblockShape::kM << "x" << Conv2d::ThreadblockShape::kN << "x" @@ -635,8 +635,8 @@ bool TestAllConv2d( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -663,8 +663,8 @@ bool TestAllConv2d( // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} // Although strided dgrad works for all stride combinations, we are only going // to run strided dgrad for non-unity strides - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -718,8 +718,8 @@ bool TestAllConv2d( } // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index 52771a7a..278d447f 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -404,9 +404,9 @@ public: // compute tensor Z and tensor T for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { ElementZ z{}; ElementT t{}; @@ -449,7 +449,8 @@ public: fname << "error_Conv2d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "nhwc_" << problem_size.N << "x" << problem_size.H << "x" @@ -602,8 +603,8 @@ bool TestAllConv2dWithBroadcast( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -613,8 +614,8 @@ bool TestAllConv2dWithBroadcast( #if 0 // relax restrictions on analytic strided dgrad // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -650,8 +651,8 @@ bool TestAllConv2dWithBroadcast( } // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { diff --git a/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu index 910e92e4..27ae274c 100644 --- a/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu @@ -111,7 +111,8 @@ TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu index e401d113..bc0dee0e 100644 --- a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu @@ -81,7 +81,8 @@ TEST(SM80_Device_Conv3d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f32ndhwc_f32n cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -103,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -void TestResidaulBlock() { +static void Conv3dFpropSM80TestResidaulBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -161,7 
+162,7 @@ void TestResidaulBlock() { TEST(SM80_Device_Conv3d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - TestResidaulBlock(); + Conv3dFpropSM80TestResidaulBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index bfbe8921..54bf9363 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -169,7 +169,7 @@ public: tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); tensor_A.sync_device(); @@ -358,12 +358,12 @@ public: bool cached_result_loaded = false; CachedTestResult cached_test_result; - std::string conv2d_result_cache_name = + std::string conv3d_result_cache_name = std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - CachedTestResultListing cached_results(conv2d_result_cache_name); + CachedTestResultListing cached_results(conv3d_result_cache_name); auto cached = cached_results.find(cached_test_key); @@ -376,7 +376,7 @@ public: if (!cached_result_loaded) { #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - + cutlass::reference::device::Conv3d< ElementA, LayoutA, @@ -426,15 +426,14 @@ public: cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - CachedTestResultListing cached_results(conv2d_result_cache_name); + CachedTestResultListing cached_results(conv3d_result_cache_name); cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); + cached_results.write(conv3d_result_cache_name); } } // if (!cached_result_loaded) uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { passed = (tensor_D_hash == cached_test_result.D); @@ -456,7 +455,8 @@ public: fname << "error_Conv3d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "ndhwc_" << problem_size.N << "x" << problem_size.D << "x" @@ -571,8 +571,8 @@ bool TestAllConv3d( // // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity) || (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == diff --git a/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/test/unit/conv/device/conv3d_with_broadcast_testbed.h index 93acfcab..cc7c06f7 100644 --- a/test/unit/conv/device/conv3d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -227,7 +227,6 @@ public: initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); - for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { @@ -239,7 +238,6 @@ public: } } } - tensor_A.sync_device(); tensor_B.sync_device(); tensor_C.sync_device(); @@ -407,10 +405,10 @@ public: // compute tensor Z and tensor T for (int n = 0; n < problem_size.N; ++n) { - for (int o = 0; o < problem_size.Z; ++o) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { ElementZ z{}; ElementT t{}; @@ -454,7 +452,8 @@ public: fname << "error_Conv3d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "nnhwc_" << problem_size.N << "x" << problem_size.D << "x" @@ -563,8 +562,8 @@ bool TestAllConv3dWithBroadcast( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_d == 1) && @@ -577,8 +576,8 @@ bool TestAllConv3dWithBroadcast( #if 0 // relax restrictions on analytic strided dgrad // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { diff --git a/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu new file mode 100644 index 00000000..73a78d33 --- /dev/null +++ b/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv2d_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_32x64x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_64x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu new file mode 100644 index 00000000..7872f8a4 --- /dev/null +++ b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * 
Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/conv/kernel/default_deconv2d_with_broadcast.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_with_broadcast_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + + +TEST(SM80_Device_Deconv2d_With_Broadcast_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC, + ElementAccumulator, + ElementCompute, + ElementC, + ElementC, + 1, + cutlass::epilogue::thread::ReLu + >; + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2dWithBroadcast< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; 
+ + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast()); +} + +// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) +// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. +// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, +// which only the last thread block would have an access to, before applying BinaryOp. +// The epilogue functor in the last thread block would have to be given three inputs, namely +// partial outputs, bias, and residual, but this is not supported in the current interface. +// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. +template < + template class ActivationOp, + template class BinaryOp, + template class UnaryOp, + bool TestSplitK = true +> +static void Deconv2dSM80TestResidaulBlock() { + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = ElementC; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementD, + ElementAccumulator, + ElementCompute, + ElementC, + 1, + ActivationOp, + BinaryOp, + UnaryOp + >; + + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2dWithBroadcast< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + struct ReferenceOp { + using OutputOp = typename Deconv2d::EpilogueOutputOp; + using ElementZ = typename OutputOp::ElementZ; + + ActivationOp activation; + BinaryOp binary_op; + UnaryOp unary_op; + + void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { + Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); + } + }; + + bool passed = test::conv::device::TestAllConv2dWithBroadcast(); + EXPECT_TRUE(passed); +} + +TEST(SM80_Device_Deconv2d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_32x64x8) { + // Resnet + Deconv2dSM80TestResidaulBlock(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu b/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu new file mode 100644 index 00000000..929a5151 --- /dev/null +++ b/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv3d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv3d_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_32x64x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3d< + ElementA, + cutlass::layout::TensorNDHWC, + ElementB, + cutlass::layout::TensorNDHWC, + ElementC, + cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3d()); + +} + +//////////////////////////////////////////////////////////////////////////////// 
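As a usage note: the tests above and below exercise these Deconv3d instantiations through test::conv::device::TestAllConv3d, which sweeps a set of problem sizes. The fragment below is a rough sketch, not part of the patch, of how a single problem size could be driven through the same device-level interface; it assumes the usual ImplicitGemmConvolution API (Arguments, can_implement, initialize, operator()) as used by the existing conv examples, requests no split-k workspace, and uses illustrative names throughout.

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/conv/conv3d_problem_size.h"

// DeconvOp is expected to be an ImplicitGemmConvolution specialization such as the
// Deconv3d aliases defined in the tests in this file (float, TensorNDHWC operands).
template <typename DeconvOp>
cutlass::Status run_one_deconv3d(
    cutlass::conv::Conv3dProblemSize const &problem_size,
    cutlass::TensorRef<float const, cutlass::layout::TensorNDHWC> ref_A,  // input operand
    cutlass::TensorRef<float const, cutlass::layout::TensorNDHWC> ref_B,  // filter
    cutlass::TensorRef<float const, cutlass::layout::TensorNDHWC> ref_C,  // source
    cutlass::TensorRef<float, cutlass::layout::TensorNDHWC> ref_D,        // output
    float alpha = 1.0f,
    float beta = 0.0f) {

  // Epilogue parameters for LinearCombination are {alpha, beta}.
  typename DeconvOp::Arguments args{problem_size, ref_A, ref_B, ref_C, ref_D, {alpha, beta}};

  DeconvOp op;
  cutlass::Status status = op.can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = op.initialize(args);   // no workspace needed without split-k
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return op();                    // launch on the default stream
}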
+TEST(SM80_Device_Deconv3d_Optimized_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_64x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3d< + ElementA, + cutlass::layout::TensorNDHWC, + ElementB, + cutlass::layout::TensorNDHWC, + ElementC, + cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu new file mode 100644 index 00000000..e0d0171f --- /dev/null +++ b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/conv/kernel/default_deconv3d_with_broadcast.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv3d_with_broadcast_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +TEST(SM80_Device_Deconv3d_With_Broadcast_Optimized_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC, + ElementAccumulator, + ElementCompute, + ElementC, + ElementC, + 1, + cutlass::epilogue::thread::ReLu + >; + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3dWithBroadcast< + ElementA, cutlass::layout::TensorNDHWC, + ElementB, cutlass::layout::TensorNDHWC, + ElementC, cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3dWithBroadcast()); +} + +// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv3d(X) + bias), residual)) +// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. +// This is because the activation needs to be applied to the fully accumulated output of the Conv3d op, +// which only the last thread block would have an access to, before applying BinaryOp. +// The epilogue functor in the last thread block would have to be given three inputs, namely +// partial outputs, bias, and residual, but this is not supported in the current interface. +// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. 
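The comment above describes the fused epilogue Z = UnaryOp(BinaryOp(ActivationOp(Conv3d(X) + bias), residual)). As a minimal host-side sketch, not part of the patch, here is that formula instantiated for one illustrative configuration, ActivationOp = ReLu, BinaryOp = plus, UnaryOp = Identity, in which case it reduces to Z = ReLU(conv + bias) + residual; the ReferenceOp functor defined below computes the same composition for whichever template parameters are under test.

#include "cutlass/functional.h"                  // cutlass::plus
#include "cutlass/epilogue/thread/activation.h"  // ReLu, Identity

// Scalar reference for the residual-block epilogue with ReLU / plus / Identity.
inline float residual_block_scalar(float conv, float bias, float residual) {
  cutlass::epilogue::thread::ReLu<float> relu;
  cutlass::plus<float> add;
  cutlass::epilogue::thread::Identity<float> identity;
  return identity(add(relu(add(conv, bias)), residual));  // ReLU(conv + bias) + residual
}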
+template < + template class ActivationOp, + template class BinaryOp, + template class UnaryOp, + bool TestSplitK = true +> +static void Deconv3dSM80TestResidaulBlock() { + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = ElementC; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementD, + ElementAccumulator, + ElementCompute, + ElementC, + 1, + ActivationOp, + BinaryOp, + UnaryOp + >; + + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3dWithBroadcast< + ElementA, cutlass::layout::TensorNDHWC, + ElementB, cutlass::layout::TensorNDHWC, + ElementC, cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + struct ReferenceOp { + using OutputOp = typename Deconv3d::EpilogueOutputOp; + using ElementZ = typename OutputOp::ElementZ; + + ActivationOp activation; + BinaryOp binary_op; + UnaryOp unary_op; + + void operator()(ElementZ &Z, ElementZ&, ElementCompute conv3d, ElementCompute residual) { + Z = ElementZ(unary_op(binary_op(activation(conv3d), residual))); + } + }; + + bool passed = test::conv::device::TestAllConv3dWithBroadcast(); + EXPECT_TRUE(passed); +} + +TEST(SM80_Device_Deconv3d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_32x64x8) { + // Resnet + Deconv3dSM80TestResidaulBlock(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device_3x/CMakeLists.txt b/test/unit/conv/device_3x/CMakeLists.txt index d6d8e215..dddeba6f 100644 --- a/test/unit/conv/device_3x/CMakeLists.txt +++ b/test/unit/conv/device_3x/CMakeLists.txt @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ add_subdirectory(fprop) add_subdirectory(wgrad) add_subdirectory(dgrad) diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index 67501ea1..e22c2cff 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -53,7 +53,6 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" - #include "conv_problem_sizes.hpp" #include "../cache_testbed_output.h" @@ -195,7 +194,8 @@ struct ConvTestbed { bool run( ProblemShape const& problem_shape, ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0)) { + ElementScalar beta = ElementScalar(0) + ) { // Waive test if insufficient CUDA device if (!sufficient()) { @@ -250,14 +250,16 @@ struct ConvTestbed { auto &fusion_args = args.epilogue.thread; - // some fused patterns have no linear combination + fusion_args.alpha = alpha; + fusion_args.beta = beta; + if constexpr (IsBiasEnabled) { fusion_args.bias_ptr = tensor_bias.data().get(); } // Clamp bound if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = ElementCompute{0}; + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); } @@ -422,17 +424,11 @@ struct ConvTestbed { reference_impl.compute_reference(); } // Validate kernel against reference - passed = compare_reference( - mD_ref, mD_computed, mA, mB, mAlpha, - mBeta, mBias, - this->epsilon); + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); } #else // Validate kernel against reference - passed = compare_reference( - mD_ref, mD_computed, mA, mB, mAlpha, - mBeta, mBias, - this->epsilon); + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); #endif EXPECT_TRUE(passed); @@ -445,8 +441,7 @@ struct ConvTestbed { class EngineB, class LayoutB, class EngineAlpha, class LayoutAlpha, class EngineBeta, class LayoutBeta, - class EngineBias, class LayoutBias - > + class EngineBias, class LayoutBias> static constexpr bool compare_reference( cute::Tensor const& reference, @@ -503,7 +498,6 @@ struct ConvTestbed { printf("[%ld]: bias = %f\n", i, float(tensor_bias(i))); } } - for (size_t i = 0; i < size_t(size(reference)); ++i) { printf("[%ld]: ref = %f, computed = %f\n", i, float(reference(i)), float(computed(i))); } diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index 40603649..d0e10a7b 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -45,30 +45,3 @@ cutlass_test_unit_add_executable( fast_numeric_conversion.cu functional.cu ) - -# -# CUTLASS 3x increases the host compiler requirements to C++17. However, there are -# certain existing integrations that will benefit from maintaining C++11 compatibility. -# -# This requirement only applies to select .h files which are explicitly annotated. It -# does not apply to any .hpp file. -# -# `cutlass_test_unit_core_cpp11` enforces the C++11 requirement. 
-# - -set(CMAKE_CUDA_STANDARD 11) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -add_executable( - cutlass_test_unit_core_cpp11 - - cpp11.cu -) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - target_compile_options( - cutlass_test_unit_core_cpp11 - PRIVATE - $<$:-Xcompiler -Werror> - ) -endif() diff --git a/test/unit/core/cpp11.cu b/test/unit/core/cpp11.cu deleted file mode 100644 index 553b031c..00000000 --- a/test/unit/core/cpp11.cu +++ /dev/null @@ -1,87 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if (201700L <= __cplusplus ) -#error "This file and all of its includes must be compilable as C++11." 
-#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -int main() { - return 0; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index dd72b840..75e12bdf 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -108,8 +108,8 @@ __global__ void convert_with_scale_factor( ///////////////////////////////////////////////////////////////////////////////////////////////// -template -void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) { +template +void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[], const int range = 4, const int offset = 0) { const int kN = Count; dim3 grid(1, 1); @@ -124,7 +124,7 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[], for (int i = 0; i < kN; ++i) { - source_ref.at({0, i}) = Source(i % Range); + source_ref.at({0, i}) = Source(i % range + offset); } for (int i = 0; i < kN; ++i) { @@ -144,10 +144,12 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[], for (int i = 0; i < kN; ++i) { float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i})); - EXPECT_TRUE(float(destination_ref.at({0, i})) == ref) - << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) - << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) - << ", Count: " << Count; + bool pass = float(destination_ref.at({0, i})) == ref; + EXPECT_TRUE(pass) + << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) << std::endl + << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) << std::endl + << ", Scalefactor type: " << source_name << " " << float(scale_factor_ref.at({0, i})) << std::endl + << ", idx: " << i << std::endl; } } diff --git a/test/unit/cute/CMakeLists.txt b/test/unit/cute/CMakeLists.txt index d001f18e..601c0c0d 100644 --- a/test/unit/cute/CMakeLists.txt +++ b/test/unit/cute/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(core) add_subdirectory(volta) +add_subdirectory(turing) add_subdirectory(ampere) add_subdirectory(hopper) add_subdirectory(layout) @@ -39,6 +40,7 @@ add_custom_target( cutlass_test_unit_cute_layout cutlass_test_unit_cute_core cutlass_test_unit_cute_volta + cutlass_test_unit_cute_turing cutlass_test_unit_cute_ampere cutlass_test_unit_cute_hopper cutlass_test_unit_cute_msvc_compilation @@ -51,6 +53,7 @@ add_custom_target( test_unit_cute_core test_unit_cute_volta test_unit_cute_ampere + test_unit_cute_turing test_unit_cute_hopper test_unit_cute_msvc_compilation ) diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt index d05a73c0..fd701de6 100644 --- a/test/unit/cute/ampere/CMakeLists.txt +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -30,6 +30,7 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_ampere cp_async.cu ldsm.cu + cooperative_gemm.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu new file mode 100644 index 00000000..2fcd0120 --- /dev/null +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 
2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA) { + using value_type = double; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm3_Half_MMA_CustomSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 128; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` + >; + + using smem_a_atom_layout_t = Layout, Stride< _1,_64>>; + using smem_b_atom_layout_t = Layout, Stride<_32, _1>>; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm4_Half_MMA_SwizzledSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 128; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32, _32, _16> // 32x32x16 MMA 
for LDSM, 1x2x1 value group` + >; + + // RowMajor + using smem_rowmajor_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + // ColMajor + using smem_colmajor_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using smem_a_atom_layout_t = smem_rowmajor_atom_layout_t; + using smem_b_atom_layout_t = smem_colmajor_atom_layout_t; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm5_Double_MMA_SwizzledSmemLayouts) { + using value_type = double; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 64; + constexpr uint32_t k = 16; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA, // Atom + Layout>, // Atom layout + Tile, Stride<_2, _1>>, // 32x32x4 MMA with perm for load vectorization + Layout, Stride<_2, _1>>, + Underscore>>; + + using smem_a_atom_layout_t = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using smem_b_atom_layout_t = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm6_MixedPrecisionFP16FP32_MMA) { + using TA = cutlass::half_t; + using TB = cutlass::half_t; + 
using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm7_MixedPrecisionBF16FP32_MMA) { + using TA = cutlass::bfloat16_t; + using TB = cutlass::bfloat16_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) { + using TA = cutlass::tfloat32_t; + using TB = cutlass::tfloat32_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 00000000..9f7f6946 --- /dev/null +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(TA const* a, + TB const* b, + TC* c, + TC* c_out, + Alpha const alpha, + Beta const beta, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), ALayout{}); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), BLayout{}); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), CLayout{}); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), CLayout{}); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(SMemALayout {})), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(SMemBLayout {})), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), SMemALayout{}); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), SMemBLayout{}); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), SMemCLayout{}); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + TiledMma tiled_mma; + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = ALayout; + using gmem_b_layout_t = BLayout; + using gmem_c_layout_t = CLayout; + + using smem_a_layout_t = SMemALayout; + using smem_b_layout_t = SMemBLayout; + using smem_c_layout_t = SMemCLayout; + + static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM + static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN + static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK + + static_assert(size<0>(smem_a_layout_t{}) == size<0>(smem_c_layout_t{})); // AM == CM + static_assert(size<0>(smem_b_layout_t{}) == size<1>(smem_c_layout_t{})); // BN == CN + static_assert(size<1>(smem_a_layout_t{}) == size<1>(smem_b_layout_t{})); // AK == BK + + static_assert(cute::size(gmem_a_layout_t {}) == cute::size(smem_a_layout_t {})); + static_assert(cute::size(gmem_b_layout_t {}) == cute::size(smem_b_layout_t {})); + static_assert(cute::size(gmem_c_layout_t {}) == cute::size(smem_c_layout_t {})); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout_t{}); print("\n"); + print(" "); print("smem: "); print(smem_layout_t{}); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto 
beta = static_cast(1.2); + + thrust::host_vector h_a(cosize(gmem_a_layout_t{})); + thrust::host_vector h_b(cosize(gmem_b_layout_t{})); + thrust::host_vector h_c(cosize(gmem_c_layout_t{})); + thrust::host_vector h_c_out(cosize(gmem_c_layout_t{})); + + auto h_a_tensor = make_tensor(h_a.data(), gmem_a_layout_t{}); + auto h_b_tensor = make_tensor(h_b.data(), gmem_b_layout_t{}); + auto h_c_tensor = make_tensor(h_c.data(), gmem_c_layout_t{}); + size_t max_size = std::max({static_cast(size(gmem_a_layout_t {})), + static_cast(size(gmem_b_layout_t {})), + static_cast(size(gmem_c_layout_t {}))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(gmem_a_layout_t{})) { + h_a_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); + } + if(i < size(gmem_b_layout_t{})) { + h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); + } + if(i < size(gmem_c_layout_t{})) { + h_c_tensor(i) = static_cast((di*di) / size(gmem_a_layout_t{})); + } + } + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + const size_t shared_memory_size = + (sizeof(TA) * h_a.size()) + (sizeof(TB) * h_b.size()) + (sizeof(TC) * h_c.size()); + auto kernel = cooperative_gemm_kernel< + gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t, + smem_a_layout_t, smem_b_layout_t, smem_c_layout_t, + SmemCopyOpA, SmemCopyOpB, SmemCopyOpC, + ThreadBlockSize, TiledMma, CopyMaxVecBits, + TA, TB, TC, decltype(alpha), decltype(beta), + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform + >; + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform + ); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + thrust::host_vector h_c_ref(h_c.size(), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), gmem_c_layout_t{}); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + h_c_out = d_c_out; + auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{}); + for (int i = 0; i < size(h_c_ref_tensor); i++) { + double h_c_ref_i = h_c_ref_tensor(i); + double h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, 
nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + test_cooperative_gemm>, + AutoVectorizingCopyWithAssumedAlignment>, + AutoVectorizingCopyWithAssumedAlignment>, + ThreadBlockSize, + TiledMMAType, + CopyMaxVecBits, + TA, + TB, + TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + test_cooperative_gemm_col_major_layout, T, T, T>( + a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + using smem_a_atom_layout_t = SMemAAtomLayout; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + using smem_b_atom_layout_t = SMemBAtomLayout; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = SMemCAtomLayout; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm>, + AutoVectorizingCopyWithAssumedAlignment>, + AutoVectorizingCopyWithAssumedAlignment>, + ThreadBlockSize, + TiledMMAType, + CopyMaxVecBits, + TA, + TB, + TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + test_cooperative_gemm_col_major_layout, + T, + T, + T>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} diff --git a/test/unit/cute/core/complement.cpp b/test/unit/cute/core/complement.cpp index 460fdede..cba486f6 100644 --- a/test/unit/cute/core/complement.cpp +++ b/test/unit/cute/core/complement.cpp @@ -35,22 +35,22 @@ #include -template +template void -test_complement(Layout const& layout, CoSizeHi const& 
cosize_hi) +test_complement(Layout const& layout, CoTarget const& cotarget) { using namespace cute; - auto result = complement(layout, cosize_hi); + auto result = complement(layout, cotarget); - CUTLASS_TRACE_HOST("complement(" << layout << ", " << cosize_hi << ") => " << result); + CUTLASS_TRACE_HOST("complement(" << layout << ", " << cotarget << ") => " << result); auto completed = make_layout(layout, result); // Lower-bound on the codomain size of the layout ++ complement (1) - EXPECT_GE(cosize(completed), cosize_hi); + EXPECT_GE(cosize(completed), size(cotarget)); // Upper-bound on the codomain size of the complement (2) - EXPECT_LE(cosize(result), cute::round_up(cosize_hi, cosize(layout))); + EXPECT_LE(cosize(result), cute::round_up(size(cotarget), cosize(layout))); // Post-condition on the codomain of the complement for (int i = 1; i < size(result); ++i) { @@ -62,9 +62,9 @@ test_complement(Layout const& layout, CoSizeHi const& cosize_hi) // Other observations EXPECT_LE(size(result), cosize(result)); // As a result of the ordered condition (3) - EXPECT_GE(size(result), cosize_hi / size(filter(layout))); + EXPECT_GE(size(result), size(cotarget) / size(filter(layout))); EXPECT_LE(cosize(completed), cosize(result) + cosize(layout)); - EXPECT_GE(cosize(result), cosize_hi / size(filter(layout))); + EXPECT_GE(cosize(result), size(cotarget) / size(filter(layout))); if constexpr (is_static::value) { // If we can apply complement again EXPECT_EQ(size(complement(completed)), 1); // There's no more codomain left over } @@ -90,6 +90,8 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<2>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -97,6 +99,8 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<2>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -105,6 +109,8 @@ TEST(CuTe_core, Complement) test_complement(layout, Int<1>{}); test_complement(layout, Int<2>{}); test_complement(layout, Int<8>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -130,6 +136,7 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<16>{}); test_complement(layout, Int<19>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -138,6 +145,7 @@ TEST(CuTe_core, Complement) test_complement(layout, Int<1>{}); test_complement(layout); test_complement(layout, Int<17>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -193,8 +201,8 @@ TEST(CuTe_core, Complement) // Fails due to non-injective layout // { - // auto layout = make_layout(Shape,Shape<_2, _2>>{}, - // Stride,Stride<_8,_4>>{}); + // auto layout = make_layout(Shape ,Shape <_2,_2>>{}, + // Stride,Stride<_8,_4>>{}); // test_complement(layout); // } @@ -289,4 +297,11 @@ TEST(CuTe_core, Complement) test_complement(layout); } + + { + auto layout = make_layout(Int<64>{}); + + test_complement(layout, make_shape(Int<32>{}, Int<4>{}, Int<4>{})); + test_complement(layout, make_shape(Int<32>{}, Int<4>{}, 4)); + } } diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp index 8f50ba5e..8e043f89 100644 --- a/test/unit/cute/core/composition.cpp +++ b/test/unit/cute/core/composition.cpp @@ -212,13 +212,12 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - // FAILS due to b not "dividing into" a properly - //{ - // auto a = make_layout(Shape<_4,_3>{}); - // auto b = 
make_layout(Shape<_6>{}); + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_6>{}); - // test_composition(a, b); - //} + test_composition(a, b); + } { auto a = make_layout(Shape<_4,_3>{}); @@ -234,13 +233,12 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - // FAILS due to b not "dividing into" a properly - //{ - // auto a = make_layout(Shape<_4,_3>{}); - // auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); - // test_composition(a, b); - //} + test_composition(a, b); + } { auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); @@ -523,4 +521,21 @@ TEST(CuTe_core, Composition) test_composition(a, b); } + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("BETA: Tuple strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = make_layout(Shape<_4,_4>{}, Stride<_4,_1>{}); + auto b = make_layout(Shape<_4,_4>{}, Stride,E<0>>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,Shape<_2,_3>>{}, Stride<_6,Stride<_3,_1>>{}); + auto b = make_layout(Shape<_2,_4>{}, Stride,E<0>>{}); + + test_composition(a, b); + } } diff --git a/test/unit/cute/core/logical_divide.cpp b/test/unit/cute/core/logical_divide.cpp index 6c1e2f29..061fd548 100644 --- a/test/unit/cute/core/logical_divide.cpp +++ b/test/unit/cute/core/logical_divide.cpp @@ -227,27 +227,42 @@ TEST(CuTe_core, Logical_divide) ASSERT_TRUE(decltype(stride<1>(result) == Int<48>{})::value); } - // DISALLOWED - //{ - //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32>{}; + { + auto layout = make_layout(make_shape(Int<32>{}, Int<4>{}, 4)); + auto tile = Layout<_64>{}; - //test_logical_divide(layout, tile); - //} + test_logical_divide(layout, tile); - //{ - //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32,_2>{}; + // Enforcement of result + auto result = logical_divide(layout, tile); + ASSERT_TRUE(bool( shape(result) == make_shape (_64{}, make_shape ( _2{}, 4)))); + ASSERT_TRUE(bool(stride(result) == make_stride( _1{}, make_stride(_64{},_128{})))); + } - //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); - //test_logical_divide(layout, tile); - //} - //{ - //auto layout = make_layout(make_shape(16,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32>{}; + // + // ALLOWED, but dangerous due to the dynamic lhs shapes + // Consider disallowing... 
+ // - //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); - //test_logical_divide(layout, tile); - //} + { + auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + auto tile = Layout<_32>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + auto tile = Layout<_32,_2>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = make_layout(make_shape(16,4,3), make_stride(1,512,0)); + auto tile = Layout<_32>{}; + + test_logical_divide(layout, tile); + } } diff --git a/test/unit/cute/hopper/CMakeLists.txt b/test/unit/cute/hopper/CMakeLists.txt index f05d86b9..0b6db66f 100644 --- a/test/unit/cute/hopper/CMakeLists.txt +++ b/test/unit/cute/hopper/CMakeLists.txt @@ -56,6 +56,11 @@ cutlass_test_unit_add_executable( tma_load.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_tma_mcast_load + tma_mcast_load.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_cute_hopper_tma_store tma_store.cu diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu index d171d36e..0105d351 100644 --- a/test/unit/cute/hopper/tma_load.cu +++ b/test/unit/cute/hopper/tma_load.cu @@ -44,7 +44,6 @@ test_tma_load(GMEM_Layout const& gmem_layout, SMEM_Layout const& smem_layout, CTA_Tile const& cta_tile) { - using namespace cute; return test_tma_load(SM90_TMA_LOAD{}, gmem_layout, smem_layout, cta_tile); } @@ -53,7 +52,6 @@ auto test_tma_load(GMEM_Layout const& gmem_layout, SMEM_Layout const& smem_layout) { - using namespace cute; return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } diff --git a/test/unit/cute/hopper/tma_mcast_load.cu b/test/unit/cute/hopper/tma_mcast_load.cu new file mode 100644 index 00000000..9a330716 --- /dev/null +++ b/test/unit/cute/hopper/tma_mcast_load.cu @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include "../hopper/tma_mcast_load_testbed.hpp" + +using namespace cute; +using namespace cutlass::test; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template > +auto +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile, + Cluster_Size const& cluster_size = {}) +{ + return test_tma_load(SM90_TMA_LOAD_MULTICAST{}, gmem_layout, smem_layout, cta_tile, cluster_size); +} + +template +auto +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} + +TEST(SM90_CuTe_Hopper, Tma_Load_32x32_Col_MCast) +{ + Layout smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = make_layout(make_shape(32,32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load< float>(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load< float>(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + } +} + +#endif diff --git a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 00000000..2fb88de5 --- /dev/null +++ b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + 
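+  // The mask above sets the low cluster_size bits, one bit per CTA in the cluster
+  // (e.g. cluster_size = 2 gives 0b11), so the multicast TMA load issued below is
+  // delivered to every CTA in the cluster. This assumes cluster_size fits within
+  // the 16-bit multicast mask.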
+#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/test/unit/cute/turing/CMakeLists.txt b/test/unit/cute/turing/CMakeLists.txt new file mode 100644 index 00000000..ac8a0487 --- /dev/null +++ b/test/unit/cute/turing/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_turing + cooperative_gemm.cu +) diff --git a/test/unit/cute/turing/cooperative_gemm.cu b/test/unit/cute/turing/cooperative_gemm.cu new file mode 100644 index 00000000..14ea9670 --- /dev/null +++ b/test/unit/cute/turing/cooperative_gemm.cu @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM75_CuTe_Turing, CooperativeGemm1_MixedPrecisionFP16FP32_MMA) { + using TA = cutlass::half_t; + using TB = cutlass::half_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} diff --git a/test/unit/cute/volta/CMakeLists.txt b/test/unit/cute/volta/CMakeLists.txt index 0777f5bf..d6688aa3 100644 --- a/test/unit/cute/volta/CMakeLists.txt +++ b/test/unit/cute/volta/CMakeLists.txt @@ -30,4 +30,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_volta vectorization_auto.cu cooperative_copy.cu + cooperative_gemm.cu ) diff --git a/test/unit/cute/volta/cooperative_copy.cu b/test/unit/cute/volta/cooperative_copy.cu index c1ffe34c..2fc80b36 100644 --- a/test/unit/cute/volta/cooperative_copy.cu +++ b/test/unit/cute/volta/cooperative_copy.cu @@ -263,6 +263,21 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1DFallback) +{ + using value_type = float; + constexpr uint32_t count = 99; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}))); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2D) { using value_type = float; @@ -279,6 +294,22 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DFallback) +{ + using value_type = float; + constexpr uint32_t x = 37; + constexpr uint32_t y = 37; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DCustomStride) { using value_type = float; @@ -312,6 +343,23 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3DFallback) +{ + using value_type = cute::half_t; + constexpr uint32_t x = 44; + constexpr uint32_t y = 24; + constexpr uint32_t z = 14; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2Dto3D) { using value_type = double; diff --git a/test/unit/cute/volta/cooperative_gemm.cu 
b/test/unit/cute/volta/cooperative_gemm.cu new file mode 100644 index 00000000..e8deb8b6 --- /dev/null +++ b/test/unit/cute/volta/cooperative_gemm.cu @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) { + using value_type = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 32; + constexpr uint32_t k = 16; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) { + using value_type = float; + + constexpr uint32_t m = 88; + constexpr uint32_t n = 20; + constexpr uint32_t k = 12; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) { + using value_type = float; + + constexpr uint32_t m = 88; + constexpr uint32_t n = 36; + constexpr uint32_t k = 24; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) { + using value_type = float; + + constexpr uint32_t m = 67; + constexpr uint32_t n = 13; + constexpr uint32_t k = 11; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) { + using value_type = double; + + constexpr uint32_t m = 16; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { + using value_type = float; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 256; + + using tiled_mma_t = TiledMMA< + MMA_Atom< + UniversalFMA + >, + Layout< + Shape<_16, _16, _1> + >, + Tile< + Layout< + Shape<_16,_2>, Stride<_2,_1> + >, // 32x32x1 MMA with perm for load vectorization + Layout< + Shape<_16,_2>, Stride<_2,_1> + >, + Underscore + > + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV; + using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using gmem_c_layout_t = 
decltype(make_layout(make_shape(Int{}, Int{}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 31; + constexpr uint32_t n = 27; + constexpr uint32_t k = 17; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<16>, // B + AutoVectorizingCopyWithAssumedAlignment<16>, // C + thread_block_size, + tiled_mma_t, + 16, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using smem_a_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using smem_b_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + // Transposed + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) { + using TA = float; + using TB = float; + using TC = double; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t 
thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom>, + Layout> + >; + + auto aload = cute::negate {}; + auto bload = cute::negate {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + auto aload = cute::negate {}; + auto bload = cute::negate {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} + +template +struct increment_by_x { + ConstantType x; + + template + CUTE_HOST_DEVICE constexpr + T operator()(const T& arg) const { + return arg + x; + } +}; + +template +struct convert_to { + CUTE_HOST_DEVICE constexpr + To operator()(const From& arg) const { + return static_cast(arg); + } +}; + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) { + using TA = float; + using TB = float; + using TC = double; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom>, + Layout> + >; + + auto aload = increment_by_x{1.111f}; + auto bload = convert_to {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} diff --git a/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/test/unit/epilogue/thread/linear_combination_planar_complex.cu index c950c12f..6cbc9589 100644 --- a/test/unit/epilogue/thread/linear_combination_planar_complex.cu +++ b/test/unit/epilogue/thread/linear_combination_planar_complex.cu @@ -183,7 +183,7 @@ TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); } - cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + cutlass::ArrayPlanarComplex destination{ linear_combination_op(accum, source) }; // Verify each result for (int i = 0; i < kCount; ++i) { diff --git a/test/unit/epilogue/threadblock/testbed.h b/test/unit/epilogue/threadblock/testbed.h index eadda470..b773d27c 100644 --- a/test/unit/epilogue/threadblock/testbed.h +++ b/test/unit/epilogue/threadblock/testbed.h @@ -42,6 +42,7 @@ #include "cutlass/half.h" #include "cutlass/complex.h" #include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/util/host_tensor.h" @@ -193,15 +194,15 @@ public: cutlass::reference::host::TensorFillRandomUniform( accumulator_tensor.host_view(), seed, - 20, - -20, + 2, + -2, 0); cutlass::reference::host::TensorFillRandomUniform( source_tensor.host_view(), seed + 2018, - 20, - -20, + 2, + -2, 0); } @@ -300,7 +301,9 @@ public: output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + output_params.beta * ElementCompute(source_tensor.at(coord)); - if (std::numeric_limits::is_integer + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) && !std::numeric_limits::is_integer) { std::fesetround(FE_TONEAREST); expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); diff --git 
a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 17a7b3fc..b0225b81 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -1336,7 +1336,6 @@ struct TestbedImpl { { using namespace cute; auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); auto epilogue_params = collective_epilogue.to_host_args(problem_size); diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu index 9c07e72d..56b85846 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -163,6 +163,50 @@ TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 128x128x32_1x1x1_cooper EXPECT_TRUE(test::gemm::device::TestAll()); } +TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 128x128x32_1x1x1_cooperative_narrow_wgmma) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + // Manually configure a half-tile wide MMA instruction + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<5, Shape<_1,_1,_1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative>, + Shape<_128,_128,_32>, + float, + cutlass::detail::TagToStrideA_t, + float, + cutlass::detail::TagToStrideB_t, + decltype(cute::make_tiled_mma(cute::SM90_64x64x8_F32TF32TF32_SS_TN{}, Layout>{})), + cute::SM90_TMA_LOAD, + cute::GMMA::Layout_K_SW128_Atom, + void, + cute::identity, + cute::SM90_TMA_LOAD, + cute::GMMA::Layout_K_SW128_Atom, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + /////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu index 743266ac..9cf8f312 100644 --- a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu @@ -54,6 +54,7 @@ #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) using namespace cute; + /////////////////////////////////////////////////////////////////////////////// //////////////////////////////// output: E4M3 ///////////////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -760,7 +761,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_non EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } - +// Use Hopper FP8+AUX from 12.1 +#if (!((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ == 0))) /////////////////////////////////////////////////////////////////////////////// ///////////////////////// output: E4M3 + Aux 
Tensor /////////////////////////// @@ -808,6 +810,7 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tenso using Gemm = cutlass::gemm::device::GemmUniversalAdapter; EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } +#endif /////////////////////////////////////////////////////////////////////////////// ////////////////////////////////// FP8 Accum ///////////////////////////////// @@ -990,6 +993,10 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_bias_bf16 EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } + +// Use Hopper FP8+AUX from 12.1 +#if (!((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ == 0))) + /////////////////////////////////////////////////////////////////////////////// ///////////////////// output: E4M3 + Aux Tensor + Bias///////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -1142,6 +1149,8 @@ TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tenso EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } +#endif + /////////////////////////////////////////////////////////////////////////////// //////////////////////////////// TMA epilogue ///////////////////////////////// /////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/nvrtc/thread/testbed.h b/test/unit/nvrtc/thread/testbed.h index 7b0b1236..6c59afeb 100644 --- a/test/unit/nvrtc/thread/testbed.h +++ b/test/unit/nvrtc/thread/testbed.h @@ -275,7 +275,7 @@ struct Testbed { nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); const char *opts[] = {"--gpu-architecture=compute_75", - "--std=c++11", + "--std=c++17", "--include-path=/usr/local/cuda-10.1/include"}; result_nvrtc = nvrtcCompileProgram(program, 3, opts); diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index f327dbae..ee7b65fe 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -109,7 +109,7 @@ struct GemmFunctionalKey { inline bool operator==(GemmFunctionalKey const &rhs) const { - return + return (provider == rhs.provider) && (gemm_kind == rhs.gemm_kind) && (element_compute == rhs.element_compute) && @@ -165,7 +165,7 @@ struct GemmFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline @@ -173,8 +173,8 @@ struct GemmFunctionalKeyHasher { IntHash hash; return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ rotl(hash(int(key.element_compute)), 3) ^ rotl(hash(int(key.element_scalar)), 4) ^ rotl(hash(int(key.element_A)), 5) ^ @@ -207,7 +207,7 @@ struct GemmPreferenceKey { GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } bool operator<(GemmPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || + return (compute_capability < rhs.compute_capability) || ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); } @@ -288,9 +288,9 @@ struct ConvFunctionalKey { layout_C(layout_C), element_accumulator(element_accumulator), element_compute(element_compute) - { } + { } - inline + inline bool operator==(ConvFunctionalKey const &rhs) const { return 
(provider == rhs.provider) && @@ -305,7 +305,7 @@ struct ConvFunctionalKey { (element_compute == rhs.element_compute); } - inline + inline bool operator!=(ConvFunctionalKey const &rhs) const { return !(*this == rhs); } @@ -325,7 +325,7 @@ std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctio << "element_accumulator: " << to_string(key.element_accumulator) << std::endl << "element_compute: " << to_string(key.element_compute) << std::endl << "}"; - + return out; } @@ -335,14 +335,14 @@ struct ConvFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline size_t operator()(ConvFunctionalKey const &key) const { IntHash hash; - return + return rotl(hash(int(key.provider)), 1) ^ rotl(hash(int(key.conv_kind)), 2) ^ rotl(hash(int(key.element_A)), 3) ^ @@ -370,11 +370,11 @@ struct ConvPreferenceKey { ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } - ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): compute_capability(cc), iterator_algorithm(iterator_algorithm) { } bool operator<(ConvPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || + return (compute_capability < rhs.compute_capability) || ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); } @@ -433,9 +433,9 @@ struct ReductionFunctionalKey { element_compute(element_compute), reduce_math_op(reduce_math_op), epilogue_math_op(epilogue_math_op) - { } + { } - inline + inline bool operator==(ReductionFunctionalKey const &rhs) const { return (provider == rhs.provider) && @@ -447,7 +447,7 @@ struct ReductionFunctionalKey { (epilogue_math_op == rhs.epilogue_math_op); } - inline + inline bool operator!=(ReductionFunctionalKey const &rhs) const { return !(*this == rhs); } @@ -459,14 +459,14 @@ struct ReductionFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline size_t operator()(ReductionFunctionalKey const &key) const { IntHash hash; - return + return rotl(hash(int(key.provider)), 1) ^ rotl(hash(int(key.element_workspace)), 2) ^ rotl(hash(int(key.element_accumulator)), 3) ^ @@ -505,19 +505,19 @@ using ReductionOperationFunctionalMap = std::unordered_map< class OperationTable { public: - /// Map of all operations of type kGemm + /// Map of all operations of type kGemm // provider (kCUTLASS) GemmOperationFunctionalMap gemm_operations; - /// Map of all operations of type kConv2d + /// Map of all operations of type kConv2d // provider (kCUTLASS, kReferenceHost, kReferenceDevice) ConvOperationFunctionalMap conv2d_operations; - /// Map of all operations of type kConv3d + /// Map of all operations of type kConv3d // provider (kCUTLASS, kReferenceHost, kReferenceDevice) ConvOperationFunctionalMap conv3d_operations; - /// Map of all operations of type kConv2d + /// Map of all operations of type kConv2d // provider (kCUTLASS) ReductionOperationFunctionalMap reduction_operations; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index b0c241e7..e50f3a1b 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -38,6 +38,7 @@ #include "cutlass/library/library.h" 
#include "library_internal.h" #include "cutlass/gemm/dispatch_policy.hpp" +#include /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -271,7 +272,6 @@ public: /// Returns success if the operation can proceed Status can_implement( void const *configuration_ptr, void const *arguments_ptr) const override { - GemmUniversalConfiguration const *configuration = static_cast(configuration_ptr); GemmUniversalArguments const *arguments = @@ -289,7 +289,6 @@ public: configuration->problem_size.n(), configuration->problem_size.k(), configuration->batch_count); - return Operator::can_implement(args); } diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index cd5887c3..2b57dbc3 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -152,6 +152,7 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kTF32; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// template struct MathOperationMap { diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index 0b37f6d5..927806d2 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -422,6 +422,8 @@ Status from_string(std::string const &str) { /////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + static struct { char const *text; char const *pretty; diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 90e86218..74087018 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -238,7 +238,9 @@ protected: DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); /// Method to profile a CUTLASS Operation Status profile_cutlass_( diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 0be2b9fe..daee0756 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -746,7 +746,13 @@ bool GemmOperationProfiler::verify_cutlass( } #endif // #if CUTLASS_ENABLE_CUBLAS - bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem); + library::GemmDescription const &gemm_desc = + static_cast(operation->description()); + + + cutlass::library::NumericTypeID element_A = gemm_desc.A.element; + cutlass::library::NumericTypeID element_B = gemm_desc.B.element; + bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B); // Update disposition to worst case verification outcome among all // verification providers which are supported @@ -912,8 +918,10 @@ bool GemmOperationProfiler::verify_with_reference_( DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) { - + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + 
cutlass::library::NumericTypeID element_B) +{ library::GemmDescription const &gemm_desc = static_cast(operation->description()); @@ -976,13 +984,13 @@ bool GemmOperationProfiler::verify_with_reference_( problem_.alpha.data(), - gemm_desc.A.element, + element_A, gemm_desc.A.layout, gemm_desc.transform_A, ptr_A, int(gemm_workspace_.configuration.lda), - gemm_desc.B.element, + element_B, gemm_desc.B.layout, gemm_desc.transform_B, ptr_B, @@ -1010,7 +1018,6 @@ bool GemmOperationProfiler::verify_with_reference_( results_.back().verification_map[provider] = Disposition::kNotRun; continue; } - results_.back().status = status; if (provider == library::Provider::kReferenceHost) { diff --git a/tools/util/include/cutlass/util/print_error.hpp b/tools/util/include/cutlass/util/print_error.hpp index aeeda92d..9eed9d14 100644 --- a/tools/util/include/cutlass/util/print_error.hpp +++ b/tools/util/include/cutlass/util/print_error.hpp @@ -62,7 +62,6 @@ template matrix_inf_norm_result matrix_inf_norm(cute::Tensor const& host_matrix) { - using std::abs; using error_type = decltype(std::declval().inf_norm); using element_type = typename EngineType::value_type; @@ -74,14 +73,25 @@ matrix_inf_norm(cute::Tensor const& host_matrix) const int64_t num_rows = cute::size<0>(host_matrix); const int64_t num_cols = cute::size<1>(host_matrix); - for(int64_t i = 0; i < num_rows; ++i) { + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + for (int64_t i = 0; i < num_rows; ++i) { error_type row_abs_sum = 0.0; for(int64_t j = 0; j < num_cols; ++j) { - row_abs_sum += abs(host_matrix(i, j)); + row_abs_sum += abs_fn(host_matrix(i, j)); } - if(std::isnan(row_abs_sum)) { + if (std::isnan(row_abs_sum)) { found_nan = true; - } else { + } + else { inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; } } @@ -95,10 +105,19 @@ matrix_inf_norm_result matrix_diff_inf_norm(cute::Tensor const& X, cute::Tensor const& Y) { - using std::abs; using error_type = decltype(std::declval().inf_norm); using element_type = typename EngineType::value_type; + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + assert(cute::size<0>(X) == cute::size<0>(Y)); assert(cute::size<1>(X) == cute::size<1>(Y)); @@ -110,15 +129,16 @@ matrix_diff_inf_norm(cute::Tensor const& X, error_type inf_norm = 0.0; bool found_nan = false; - for(int64_t i = 0; i < num_rows; ++i) { + for (int64_t i = 0; i < num_rows; ++i) { error_type row_abs_sum = 0.0; - for(int64_t j = 0; j < num_cols; ++j) { - row_abs_sum += error_type(abs(element_type(X(i,j)) - - element_type(Y(i,j)))); + for (int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - + element_type(Y(i,j)))); } - if(std::isnan(row_abs_sum)) { + if (std::isnan(row_abs_sum)) { found_nan = true; - } else { + } + else { inf_norm = row_abs_sum > inf_norm ? 
row_abs_sum : inf_norm; } } @@ -130,7 +150,7 @@ template -void +auto print_matrix_multiply_mollified_relative_error( char const A_value_type_name[], cute::Tensor const& A, @@ -158,13 +178,13 @@ print_matrix_multiply_mollified_relative_error( using std::cout; using cute::shape; cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' - << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' - << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' - << std::scientific - << "Infinity norm of A: " << A_norm << '\n' - << "Infinity norm of B: " << B_norm << '\n' - << "Infinity norm of C: " << C_norm << '\n' - << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; if(A_norm_times_B_norm == 0.0) { cout << "Mollified relative error: " << relative_error << '\n'; @@ -173,15 +193,16 @@ print_matrix_multiply_mollified_relative_error( } if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { - cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' - << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? 
"yes" : "no") << '\n'; } + return relative_error; } template -void +auto print_matrix_multiply_mollified_relative_error( const char value_type_name[], const cute::Tensor& A, @@ -189,7 +210,7 @@ print_matrix_multiply_mollified_relative_error( const cute::Tensor& C_computed, const cute::Tensor& C_expected) { - print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, value_type_name, C_computed, C_expected); } @@ -314,7 +335,7 @@ print_relative_error( bool print_error = true, double error_margin = 0.00001) { assert(size(data) == size(reference)); - return print_relative_error(static_cast(size(data)), - data, reference, + return print_relative_error(static_cast(size(data)), + data, reference, print_verbose, print_error, error_margin); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 93e559e1..05b877a2 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -1713,7 +1713,7 @@ void BlockFillSequential( Layout layout = Layout::packed(size); TensorView view(ptr, layout, size); - Array c; + Array c{}; c[0] = v; TensorFillLinear(view, c, s); diff --git a/tools/util/include/cutlass/util/reference/host/conv.hpp b/tools/util/include/cutlass/util/reference/host/conv.hpp index cbca2df6..202091d9 100644 --- a/tools/util/include/cutlass/util/reference/host/conv.hpp +++ b/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -41,6 +41,8 @@ #include "cute/tensor.hpp" +#include + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::reference::host { @@ -93,7 +95,8 @@ template< class TensorAlpha_, class TensorBeta_, class TensorBias_, - class ActivationFunctor_ = cutlass::epilogue::thread::Identity> + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> struct ConvEpilogueFusionParams { using ElementAcc = ElementAcc_; using ElementScalar = ElementScalar_; @@ -104,7 +107,6 @@ struct ConvEpilogueFusionParams { using TensorBeta = TensorBeta_; using TensorBias = TensorBias_; using ActivationFunctor = ActivationFunctor_; - ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); @@ -155,6 +157,7 @@ struct ConvReferenceImpl { // Epilogue activation operation ActivationFunctor epi_activation; + ConvReferenceImpl( TensorA const& tensor_a, TensorB const& tensor_b, @@ -201,7 +204,7 @@ private: #pragma omp parallel for collapse(2) #endif for (int32_t n = 0; n < N; ++n) { - for (int32_t q = 0; q < Q; ++q) { + for (int32_t q = 0; q < Q; ++q) { for (int32_t k = 0; k < K; ++k) { auto accumulator = ElementAcc(0); for (int32_t s = 0; s < S; ++s) { @@ -226,6 +229,7 @@ private: } } } + } // Specialization for 2D fprop kernel @@ -272,6 +276,7 @@ private: } } } + } // Specialization for 3D fprop kernel @@ -325,6 +330,7 @@ private: } } } + } // Specialization for 1D dgrad kernel @@ -371,6 +377,7 @@ private: } } } + } // Specialization for 2D dgrad kernel @@ -424,11 +431,14 @@ private: if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { output += bias_converter(epi_fusion_params_.tensor_bias[c]); } + output = epi_activation(output); + tensor_d_(c, w, h, n) = output_converter(output); } } } } + } // Specialization for 3D dgrad kernel @@ -501,6 +511,7 @@ private: } } } + } // Specialization for 1D wgrad kernel 
diff --git a/tools/util/include/cutlass/util/reference/host/convolution.h b/tools/util/include/cutlass/util/reference/host/convolution.h index 07d3681f..f28b4a65 100644 --- a/tools/util/include/cutlass/util/reference/host/convolution.h +++ b/tools/util/include/cutlass/util/reference/host/convolution.h @@ -197,7 +197,7 @@ void Depsep_Fprop(cutlass::TensorView tensor_A, } //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Dgrad +/// Dgrad / Deconv //////////////////////////////////////////////////////////////////////////////////////////////////// /// dx = dgrad(dy, w) @@ -221,7 +221,8 @@ void Conv2dDgrad( TensorRef tensor_dx_in, TensorRef tensor_dx_out, ElementCompute alpha, - ElementCompute beta) { + ElementCompute beta, + bool is_deconv = false) { ConvertOp convert_op; InnerProductOp inner_product_op; @@ -272,7 +273,8 @@ void Conv2dDgrad( if (p < problem_size.P && q < problem_size.Q) { ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); - ElementB b = tensor_w.at(cutlass::make_Coord(k, r, s, c)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); } @@ -420,6 +422,7 @@ void Conv2d( >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); break; + case conv::Operator::kDeconv: case conv::Operator::kDgrad: Conv2dDgrad< ElementA, LayoutA, @@ -429,7 +432,7 @@ void Conv2d( ElementAccumulator, ElementD, ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); break; case conv::Operator::kWgrad: @@ -537,7 +540,7 @@ void Conv3dFprop( } //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Dgrad +/// Dgrad / Deconv //////////////////////////////////////////////////////////////////////////////////////////////////// /// dx = dgrad(dy, w) @@ -560,7 +563,8 @@ void Conv3dDgrad( TensorRef tensor_dx_in, TensorRef tensor_dx_out, ElementCompute alpha, - ElementCompute beta) { + ElementCompute beta, + bool is_deconv = false) { ConvertOp convert_op; InnerProductOp inner_product_op; @@ -604,8 +608,8 @@ void Conv3dDgrad( if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); - ElementB b = tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); - + ElementB b = is_deconv ? 
tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); } } @@ -760,6 +764,7 @@ void Conv3d( >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); break; + case conv::Operator::kDeconv: case conv::Operator::kDgrad: Conv3dDgrad< ElementA, LayoutA, @@ -768,7 +773,7 @@ void Conv3d( ElementCompute, ElementAccumulator, ConvertOp, InnerProductOp - >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); break; case conv::Operator::kWgrad: diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 978b666c..84aa9363 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -35,10 +35,11 @@ #pragma once ///////////////////////////////////////////////////////////////////////////////////////////////// - +#include "cutlass/gemm/gemm.h" #include "cutlass/complex.h" #include "cutlass/numeric_conversion.h" #include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" #include "cute/tensor.hpp" @@ -115,7 +116,6 @@ struct GettEpilogueParams { using LayoutC = typename TensorC::layout_type; using EngineD = typename TensorD::engine_type; using LayoutD = typename TensorD::layout_type; - static constexpr bool PerColumnBias = PerColumnBias_; ElementScalar alpha = ElementScalar(1);