diff --git a/CHANGELOG.md b/CHANGELOG.md index c98cdb51..1ba870eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,4 @@ # NVIDIA CUTLASS Changelog - ## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03) - [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). @@ -8,7 +7,12 @@ + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) - A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. -- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md). + + Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. + + Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading. + + Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`. + + Other general optimizations. +- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead. - [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). - [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). - [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. @@ -18,7 +22,27 @@ - [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). - A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) - Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). - Various improvements and fixed from the community and CUTLASS team. Thanks to everyone who submitted PRs! +- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! + +- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu) +- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48) +- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and +[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu). +- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence: + + [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411). + + [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456). +- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. +- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). +- Support for residual add (beta != 0) in convolution kernels. +- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. +- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- Better support for MSVC as a host compiler. +- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. +- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. ## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25) @@ -51,7 +75,7 @@ + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. -- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x +- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. - 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. @@ -82,7 +106,7 @@ * [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. * [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}. * [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors. -* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. +* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. * Profiler support for lower-aligned Hopper GEMMs. * Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion). * Sub-Byte type fixes and improvements. @@ -159,10 +183,10 @@ * [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of A matrix. B and output matrices are still dense. The block size can be arbitary. * Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N. * [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added - * [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM. + * [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - use direct conv to compute instead of implicit GEMM. * The restrictions are: 1) input ,output channel and group number should be multiple of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration. * [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - which puts stride and dilation into templates to further improve the performance. In this mode, kernel persistents some inputs into register to squeeze more performance, so large filter/stride/dilation is not recommanded. - * The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration. + * The restrictions are: 1) input, output channel and group number should be multiple of (128 / sizeof(input element)). 2) input filter size, stride, dilation should same as the template parameter configuration. * [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMM. Its implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/). * [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115). * Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers). @@ -173,13 +197,13 @@ * CUDA 10.2 ## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23) -* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. +* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, Group GEMM for different data types as well as different epilogue flavours. * Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. Threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too. * Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine tune the performance. * [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention). This general group gemm based MHA does not require the sequence length of all GEMMs to be the same which makes it most useful for natural language processing. * [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts and both of them can be fused into the GEMMs before and after separately. In addition to use square sum to compute variance of layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if square sum raise numerical issues. * [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply user provided permutation layout mapping in the GEMM epilogue. -* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes: +* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized. The restrictions are: 1) input and output channel number should be multiple of group number. 2) split-K is not supported. The implementation has 2 modes: * kSingleGroup: output channel per group is multiple of Threadblock tile N. * kMultipleGroup: Threadblock tile N is multiple of output channel per group. * [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution which is also Analytical for now. The restrictions are: 1) SIMT only 2) No split-K 3) input channel equals to output channel equals to group number. @@ -235,7 +259,7 @@ * [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu) * **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu * [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu) - * [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu) + * [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu) * [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h) * **Grouped GEMM:** similar to batched GEMM with distinct problem size per group * [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM @@ -274,7 +298,7 @@ * [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) * [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) * 64b tensor strides and leading dimensions support for GEMMs - * Affine rank=2 matrix layouts + * Affine rank=2 matrix layouts * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h) * Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM. * [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation @@ -289,7 +313,7 @@ * Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations * Performance improvement for FP16 tensor core kernels * Bug fixes - * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. + * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. * Updated minimum CUDA Toolkit requirement to 10.2 * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended * Corrections and bug fixes reported by the CUTLASS community @@ -308,7 +332,7 @@ * [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md) * Corrections and bug fixes reported by the CUTLASS community * Thank you for filing these issues! - + ## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19) * Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs @@ -316,7 +340,7 @@ * Data type: FP32, complex, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32 * Spatial dimensions: 1-D, 2-D, and 3-D * Layout: NHWC, NCxHWx - * Implicit GEMM convolution components: + * Implicit GEMM convolution components: * Global memory iterators supporting Fprop, Dgrad, and Wgrad * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures @@ -332,17 +356,17 @@ * Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code * [Floating-point constants](./include/cutlass/constants.h) * NVIDIA Ampere GPU Architecture examples and documentation: - * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and + * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue) ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) - * Fast Tensor Core operations: + * Fast Tensor Core operations: * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) * Tensor Float 32, BFloat16, and double-precision data types * Mixed integer data types (int8, int4, bin1) - * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) + * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required) * Features: * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM diff --git a/CMakeLists.txt b/CMakeLists.txt index e61b66a8..e9c501bc 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -59,6 +59,18 @@ project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_C ################################################################################ +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + set(CUTLASS_GNU_HOST_COMPILE ON CACHE BOOL "Using GNU tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_HOST_COMPILE ON CACHE BOOL "Using Clang tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set(CUTLASS_MSVC_HOST_COMPILE ON CACHE BOOL "Using MSVC tools for host code compilation") +endif() + +################################################################################ + include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -67,11 +79,11 @@ elseif (CUDA_VERSION VERSION_LESS 11.4) message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.") endif() -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3) +if(CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3) message(FATAL_ERROR "GCC version must be at least 7.3!") endif() -if (CUDA_COMPILER MATCHES "[Cc]lang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") endif() find_package(Doxygen QUIET) @@ -85,13 +97,10 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -if(CUTLASS_NATIVE_CUDA) - set(CMAKE_CUDA_STANDARD 17) - set(CMAKE_CUDA_STANDARD_REQUIRED ON) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) -else() - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++17) -endif() +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) @@ -146,13 +155,13 @@ endif() ################################################################################ set(CUTLASS_NVCC_ARCHS_SUPPORTED "") -if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87) endif() -if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90) endif() -if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -246,7 +255,7 @@ set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH") if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) # If a kernel filter file is specified, we want to generate and then # filter on the entire kernel set, not the default kernel - # (sub)set. The user may have overridden CUTLASS_LIBRRARY_KERNELS, in which + # (sub)set. The user may have overridden CUTLASS_LIBRARY_KERNELS, in which # case the resulting kernel set will be the intersection of the two # options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS. set(CUTLASS_LIBRARY_KERNELS_INIT "*") @@ -375,15 +384,22 @@ endif() # Warnings-as-error exceptions and warning suppressions for Clang builds -if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=implicit-int-conversion ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=implicit-int-conversion" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pass-failed ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=pass-failed" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=inconsistent-missing-override ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=inconsistent-missing-override" ) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion ") - list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-sign-conversion" ) +if (CUTLASS_CLANG_HOST_COMPILE) + + set(FLAGS_TO_ADD + "-Wno-error=implicit-int-conversion" + "-Wno-error=pass-failed" + "-Wno-error=inconsistent-missing-override" + "-Wno-sign-conversion" + "-Wno-unused-parameter" + ) + + foreach(FLAG ${FLAGS_TO_ADD}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}") + endforeach() + endif() if (NOT MSVC AND CUTLASS_NVCC_KEEP) @@ -396,9 +412,9 @@ endif() if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1) - if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) + if (CUTLASS_GNU_HOST_COMPILE OR CUTLASS_CLANG_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c) - elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC")) + elseif(CUTLASS_MSVC_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2) endif() endif() @@ -423,19 +439,8 @@ if (NOT CMAKE_BUILD_TYPE MATCHES "Release") list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo) endif() -#Report CUDA build flags -if (CUDA_COMPILER MATCHES "[Cc]lang") - if(CUTLASS_CUDA_CLANG_FLAGS) - message(STATUS "Using CLANG flags: ${CUTLASS_CUDA_CLANG_FLAGS}") - endif() -else() - if(CUTLASS_CUDA_NVCC_FLAGS) - message(STATUS "Using NVCC flags: ${CUTLASS_CUDA_NVCC_FLAGS}") - endif() -endif() - -if(CUDA_COMPILER MATCHES "[Cc]lang") - if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) +if (CUTLASS_CLANG_DEVICE_COMPILE) + if (NOT CUTLASS_CLANG_HOST_COMPILE) message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) endif() @@ -451,12 +456,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) - string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION}) - list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR) - list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}) - # needed for libcublasLt.so in case it's installed in the same location as libcudart.so # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags) # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. @@ -464,11 +465,26 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") link_libraries(nvidia::cudart) link_libraries(nvidia::cuda_driver) + endif() +#Report CUDA build flags +if (CUTLASS_CLANG_DEVICE_COMPILE AND CUTLASS_CUDA_CLANG_FLAGS) + set(__FLAG_GROUP Clang) + set(__FLAG_LIST CUTLASS_CUDA_CLANG_FLAGS) +else(CUTLASS_NVCC_DEVICE_COMPILE AND CUTLASS_CUDA_NVCC_FLAGS) + set(__FLAG_GROUP NVCC) + set(__FLAG_LIST CUTLASS_CUDA_NVCC_FLAGS) +endif() + +set(__FLAG_DISPLAY_STRING "") +set(__FLAG_DISPLAY_SEPARATOR) +list(JOIN ${__FLAG_LIST} "\n " __FLAG_DISPLAY_STRING) +message(STATUS "Using the following ${__FLAG_GROUP} flags: \n ${__FLAG_DISPLAY_STRING}") + # Known gcc 8.1-8.3 SFINAE issue (fixed in gcc 8.4), check https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87748 # Also see https://github.com/NVIDIA/nccl/issues/835 for nvtx3.hpp -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) +if (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") endif() @@ -478,12 +494,10 @@ if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES " set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ") endif() -if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) - # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this - # property for CMake 3.18+, so we request the NEW behavior for correct compatibility. - # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 - cmake_policy(SET CMP0104 NEW) -endif() +# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this +# property for CMake 3.18+, so we request the NEW behavior for correct compatibility. +# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 +cmake_policy(SET CMP0104 NEW) if (MSVC) @@ -519,55 +533,21 @@ function(cutlass_apply_cuda_gencode_flags TARGET) set(ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS_ENABLED}) endif() - set(NVCC_FLAGS) - set(CLANG_FLAGS) set(__CMAKE_CUDA_ARCHS) foreach(ARCH ${ARCHS_ENABLED}) - list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) set(CODES) if(CUTLASS_NVCC_EMBED_CUBIN) - list(APPEND CODES sm_${ARCH}) list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real) endif() - if(CUTLASS_NVCC_EMBED_PTX) - list(APPEND CODES compute_${ARCH}) + if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE) + # If we're using clang for device compilation, the ptx is inserted + # via another command line option and the `-virtual` flags will cause an error. list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual) endif() list(JOIN CODES "," CODES_STR) - list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) endforeach() - if (NOT __SM_ARCHS) - if (CUDA_COMPILER MATCHES "[Cc]lang") - target_compile_options( - ${TARGET} - PRIVATE - $<$:${CLANG_FLAGS}> - ) - elseif(CMAKE_VERSION GREATER_EQUAL 3.18) - set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) - else() - target_compile_options( - ${TARGET} - PRIVATE - $<$:${NVCC_FLAGS}> - ) - endif() - else() - list(JOIN CLANG_FLAGS " " CLANG_FLAGS_STR) - list(JOIN NVCC_FLAGS " " STR_NVCC_FLAGS) - if (CUDA_COMPILER MATCHES "[Cc]lang") - if(${TARGET} MATCHES ".*\.cpp") - set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${CLANG_FLAGS_STR}) - endif() - elseif(CMAKE_VERSION GREATER_EQUAL 3.18) - set_source_files_properties(${TARGET} PROPERTIES CUDA_ARCHITECTURES ${STR_NVCC_FLAGS}) - else() - if(${TARGET} MATCHES ".*\.cu") - set_source_files_properties(${TARGET} PROPERTIES COMPILE_FLAGS ${STR_NVCC_FLAGS}) - endif() - endif() - endif() + set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) endfunction() @@ -588,8 +568,8 @@ set(__CUTLASS_CUDA_NVCC_FLAGS_DEBUG ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG} CACHE INTER function(cutlass_apply_standard_compile_options TARGET) - if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUDA_COMPILE_LANGUAGE CXX) + if(CUTLASS_CLANG_DEVICE_COMPILE) + set(CUDA_COMPILE_LANGUAGE CUDA) set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_CLANG_FLAGS}) set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_CLANG_FLAGS_RELEASE}) set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO}) @@ -682,8 +662,6 @@ target_include_directories( $ $ $ - $ - $ ) # Mark CTK headers as system to supress warnings from them @@ -825,7 +803,7 @@ function(cutlass_add_executable_tests NAME TARGET) # TEST_SETS_SUPPORTED: A list of test set names these tests support. # - set(options DISABLE_EXECUTABLE_INSTALL_RULE) + set(options DISABLE_EXECUTABLE_INSTALL_RULE DO_NOT_LOWERCASE_TEST_NAME) set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX) set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -915,11 +893,15 @@ function(cutlass_add_executable_tests NAME TARGET) foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) - string(TOLOWER "${NAME}_${CMD_OPTIONS_VAR}" TESTCASE_NAME) + set(TESTCASE_NAME "${NAME}_${CMD_OPTIONS_VAR}") else() - string(TOLOWER "${NAME}" TESTCASE_NAME) + set(TESTCASE_NAME "${NAME}") endif() + if (NOT __DO_NOT_LOWERCASE_TEST_NAME) + string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME) + endif() + # The following rigmarole is needed to deal with spaces and possible quotes in # command line arguments. The options are passed "by reference" as the actual # variable names holding the real options. We then expand these in a way that diff --git a/CUDA.cmake b/CUDA.cmake index 755b7476..7e91adb8 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -26,49 +26,46 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUTLASS_NATIVE_CUDA_INIT ON) -elseif(CMAKE_VERSION VERSION_LESS 3.12.4) - set(CUTLASS_NATIVE_CUDA_INIT OFF) -else() - set(CUTLASS_NATIVE_CUDA_INIT ON) +if (CUDA_COMPILER MATCHES "[Cc]lang") + message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.") + set(__CLANG_DEVICE_COMPILATION_REQUESTED ON) +elseif(CUDA_COMPILER) + message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.") endif() -set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow") - -if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH}) - # For backward compatibility, allow use of CUDA_PATH. - set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc) +if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER) + set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out endif() -if(CUTLASS_NATIVE_CUDA) +enable_language(CUDA) +find_package(CUDAToolkit REQUIRED) - enable_language(CUDA) - - if(NOT CUDA_VERSION) - set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) - endif() - if(NOT CUDA_TOOLKIT_ROOT_DIR) - get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) - endif() +if(NOT CUDA_VERSION) + # For backward compatibility with older CMake code. + set(CUDA_VERSION ${CUDAToolkit_VERSION}) + set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR}) + set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR}) +endif() +if(NOT CUDA_TOOLKIT_ROOT_DIR) + # In some scenarios, such as clang device compilation, the toolkit root may not be set, so we + # force it here to the nvcc we found via the CUDAToolkit package. + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE) +endif() +if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])") + set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation") +elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation") else() + message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.") +endif() - find_package(CUDA REQUIRED) - # We workaround missing variables with the native flow by also finding the CUDA toolkit the old way. - - if(NOT CMAKE_CUDA_COMPILER_VERSION) - set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION}) - endif() - +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30") + message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.") endif() if (CUDA_VERSION VERSION_LESS 9.2) - message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.") -endif() -if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang") - set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) - message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.") endif() find_library( @@ -211,16 +208,6 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) # Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include # paths by default, so we add it explicitly here. -function(cutlass_correct_source_file_language_property) - if(CUDA_COMPILER MATCHES "[Cc]lang") - foreach(File ${ARGN}) - if(File MATCHES ".*\.cu$") - set_source_files_properties(${File} PROPERTIES LANGUAGE CXX) - endif() - endforeach() - endif() -endfunction() - if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all") set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON) else() @@ -306,18 +293,13 @@ function(cutlass_add_library NAME) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_library(${NAME} ${TARGET_SOURCE_ARGS} "") - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS} "") - endif() + + add_library(${NAME} ${TARGET_SOURCE_ARGS} "") cutlass_apply_standard_compile_options(${NAME}) + if (NOT __SKIP_GENCODE_FLAGS) - cutlass_apply_cuda_gencode_flags(${NAME}) + cutlass_apply_cuda_gencode_flags(${NAME}) endif() target_compile_features( @@ -359,13 +341,7 @@ function(cutlass_add_executable NAME) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_executable(${NAME} ${TARGET_SOURCE_ARGS}) - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS}) - endif() + add_executable(${NAME} ${TARGET_SOURCE_ARGS}) cutlass_apply_standard_compile_options(${NAME}) cutlass_apply_cuda_gencode_flags(${NAME}) @@ -388,7 +364,6 @@ function(cutlass_target_sources NAME) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) target_sources(${NAME} ${TARGET_SOURCE_ARGS}) endfunction() diff --git a/README.md b/README.md index efe47872..e61335f2 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + # What's New in CUTLASS 3.6 CUTLASS 3.6.0 is an update to CUTLASS adding: diff --git a/examples/13_two_tensor_op_fusion/CMakeLists.txt b/examples/13_two_tensor_op_fusion/CMakeLists.txt index 0b1e2cdf..6819a976 100644 --- a/examples/13_two_tensor_op_fusion/CMakeLists.txt +++ b/examples/13_two_tensor_op_fusion/CMakeLists.txt @@ -80,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE}) endforeach() - diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index 885aaceb..55852730 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -59,11 +59,11 @@ // Also, we don't check the index value is legal and index array point is valid // for the sake of the performance. -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include diff --git a/examples/39_gemm_permute/layouts.h b/examples/39_gemm_permute/layouts.h index ffb27c4e..3632ec0a 100644 --- a/examples/39_gemm_permute/layouts.h +++ b/examples/39_gemm_permute/layouts.h @@ -33,11 +33,7 @@ computing reference permutations of 4/5D tensors when source data is column-major. */ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include "assert.h" -#endif #include "cutlass/cutlass.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h index 90c0a69b..efca4f13 100644 --- a/examples/41_fused_multi_head_attention/debug_utils.h +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -30,8 +30,8 @@ **************************************************************************************************/ #pragma once -#include -#include +#include +#include #include //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h index f8f06dfe..e166af4d 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -43,11 +43,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" @@ -57,12 +53,9 @@ #include "cutlass/layout/vector.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_coord.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/numeric_types.h" diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h index 24530a0f..6860ee9e 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -43,11 +43,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" @@ -57,16 +53,12 @@ #include "cutlass/layout/vector.h" #include "cutlass/numeric_types.h" #include "cutlass/tensor_coord.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" #include "cutlass/numeric_types.h" - #include "cutlass/array.h" #include "cutlass/cutlass.h" #include "cutlass/epilogue/thread/scale_type.h" diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h index 22779b59..5a2f928a 100644 --- a/examples/41_fused_multi_head_attention/fmha_grouped.h +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -550,7 +550,7 @@ public: auto prologueV = [&](int blockN) { typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], {problem_size_1_k, problem_size_1_n}, thread_id(), @@ -719,7 +719,7 @@ public: } typename MM1::Mma::IteratorB iterator_V( - typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])}, + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], {problem_size_1_k, problem_size_1_n}, thread_id(), @@ -761,15 +761,15 @@ public: using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, output_accum_t, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -777,7 +777,7 @@ public: typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -795,7 +795,7 @@ public: int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = gemm_kernel_utils::call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); @@ -817,8 +817,8 @@ public: } if (kKeepOutputInRF) { - const bool kIsFirst = true; - const bool kIsLast = true; + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; using DefaultEpilogue = typename MM1::DefaultEpilogue; using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; using ElementCompute = typename DefaultOp::ElementCompute; diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h index 18e61f6f..a770e0b6 100644 --- a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -55,13 +55,14 @@ #define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ { \ if (BOOL_V) { \ - constexpr bool BOOL_NAME = true; \ + using BOOL_NAME = std::true_type; \ F(); \ } else { \ - constexpr bool BOOL_NAME = false; \ + using BOOL_NAME = std::false_type; \ F(); \ } \ } + #define DISPATCH_ARCHTAG(CC, func) \ { \ if (CC >= 80) { \ diff --git a/examples/41_fused_multi_head_attention/kernel_backward.h b/examples/41_fused_multi_head_attention/kernel_backward.h index e7372f13..6fd94a6c 100644 --- a/examples/41_fused_multi_head_attention/kernel_backward.h +++ b/examples/41_fused_multi_head_attention/kernel_backward.h @@ -32,6 +32,7 @@ #pragma once #include +#include #include #include @@ -85,8 +86,6 @@ #include "gemm/mma_from_smem.h" #include "transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -1956,7 +1955,8 @@ struct AttentionBackwardKernel { // no-op epilogue operator - just casting and storing contents of // accum to global memory - typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op( + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1}); typename MatmulDOIVJ::BiasGradEpilogue epilogue( shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); epilogue(output_op, output_iter, accum, output_iter); @@ -2211,7 +2211,7 @@ struct AttentionBackwardKernel { incrIteration(p, query_start, key_start, next_query, next_key); DISPATCH_BOOL( next_key != key_start, kForceReloadK, ([&]() { - prologueQkNextIteration( + prologueQkNextIteration( shared_storage, p, next_query, next_key, warp_id, lane_id); })); } @@ -2342,7 +2342,7 @@ struct AttentionBackwardKernel { thread_id, cutlass::MatrixCoord{0, 0}); - MatmulQK::Mma::prologue( + MatmulQK::Mma::template prologue( shared_storage.mm_qk_k(), shared_storage.mm_qk_q(), iterator_A, @@ -2369,6 +2369,7 @@ struct AttentionBackwardKernel { p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, thread_id); + accumulateInGmem( shared_storage.gradV_epilogue_final(), output_frags.gradV, @@ -2406,7 +2407,7 @@ struct AttentionBackwardKernel { int thread_id = 32 * warp_id + lane_id; DISPATCH_BOOL( first, kIsFirst, ([&]() { - static constexpr auto ScaleType = kIsFirst + static constexpr auto ScaleType = kIsFirst::value ? cutlass::epilogue::thread::ScaleType::Nothing : cutlass::epilogue::thread::ScaleType::NoBetaScaling; using EpilogueOutputOp = diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h index 4c80f549..71d79415 100644 --- a/examples/41_fused_multi_head_attention/kernel_forward.h +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -38,6 +38,7 @@ #include #include +#include #include #include "cutlass/fast_math.h" @@ -71,8 +72,6 @@ #include "gemm_kernel_utils.h" #include "transform/tile_smem_loader.h" -#include - using namespace gemm_kernel_utils; namespace { @@ -1036,15 +1035,15 @@ struct AttentionKernel { using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, output_t, output_accum_t>::type, output_accum_t, DefaultOp::kCount, typename DefaultOp::ElementAccumulator, ElementCompute, - kIsFirst, - kIsLast, + kIsFirst::value, + kIsLast::value, cutlass::Array>; using Epilogue = typename cutlass::epilogue::threadblock:: EpiloguePipelined< @@ -1052,7 +1051,7 @@ struct AttentionKernel { typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, typename cutlass::platform::conditional< - kIsLast, + kIsLast::value, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, typename DefaultEpilogue:: @@ -1070,7 +1069,7 @@ struct AttentionKernel { int col = blockN * MM1::Mma::Shape::kN; auto source_iter = createOutputAccumIter(col); auto dest_iter = call_conditional< - kIsLast, + kIsLast::value, decltype(createOutputIter), decltype(createOutputAccumIter)>:: apply(createOutputIter, createOutputAccumIter, col); diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h index afa20871..1acb4a2d 100644 --- a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -53,12 +49,9 @@ #include "cutlass/tensor_coord.h" #include "cutlass/aligned_buffer.h" #include "cutlass/functional.h" - #include "cutlass/gemm/gemm.h" - #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/transform/threadblock/regular_tile_iterator.h" - #include "cutlass/epilogue/threadblock/epilogue_base.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py index a1f2998c..6474d95c 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -43,7 +43,7 @@ class gen_test: def gen_cpp_sample(self): code = "/* Auto Generated code - Do not edit.*/\n" - code += "#include \n" + code += "#include \n" code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" code += "#include \"cutlass/cutlass.h\" \n" diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py index 117179ed..db1ec4c7 100644 --- a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -380,7 +380,7 @@ class gen_one_API: def gen_CUTLASS_irrelevant_API(self): code = "" code += "#include \n" - code += "#include \n" + code += "#include \n" param_name = "Fused" + str(self.b2b_num) + "xGemm_" for i in range(self.b2b_num): diff --git a/examples/45_dual_gemm/test_run.h b/examples/45_dual_gemm/test_run.h index 2bd6c720..4a58a3a1 100644 --- a/examples/45_dual_gemm/test_run.h +++ b/examples/45_dual_gemm/test_run.h @@ -66,7 +66,7 @@ int testRun(int arch, std::vector & test_funcs, const std::string & return -1; } - if (!(props.major == arch_major && props.minor == arch_minor)) { + if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) { supported = false; } diff --git a/examples/45_dual_gemm/threadblock/dual_epilogue.h b/examples/45_dual_gemm/threadblock/dual_epilogue.h index cd2288af..3ef1c6d3 100644 --- a/examples/45_dual_gemm/threadblock/dual_epilogue.h +++ b/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index 884f3535..0a74e02a 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -45,18 +45,18 @@ and BEFORE scatter operations are applied. */ -#include -#include -#include -#include -#include -#include - +#include +#include +#include +#include +#include #include #include #include #include +#include + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -64,7 +64,6 @@ #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" - #include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index c1d3caeb..ab82b40c 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -619,7 +619,6 @@ int main(int argc, char const **args) { << "later (compute capability 90 or greater).\n"; return 0; } - // // Parse options // @@ -681,4 +680,4 @@ int main(int argc, char const **args) { return 0; } -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index 5fc96d4e..40fa6894 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -559,4 +559,4 @@ int main(int argc, char const **args) { return 0; } -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 8d22c75f..ecb4f41c 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -9,12 +9,17 @@ This first version only supports mixed type GEMMs using TMA. ## Performance -While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type. The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. + +Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details. + + We are currently optimizing the following cases: 1. Memory bound cases for all types +2. `fp8 x {int2, uint2}` case ## Limitations @@ -36,4 +41,4 @@ We are currently optimizing the following cases: * Optimizations for memory bound cases. -* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. \ No newline at end of file +* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. diff --git a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp index fdb31316..55de3fab 100644 --- a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -151,16 +151,16 @@ void mixed_dtype_profiling( runtimes.reserve(options.iterations); for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { - cudaEventRecord(start); - CUTLASS_CHECK(gemm.run()); - cudaEventRecord(stop); - cudaEventSynchronize(stop); + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); - if (iter >= options.warmup) { - float milliseconds = 0; - cudaEventElapsedTime(&milliseconds, start, stop); - runtimes.push_back(milliseconds); - } + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } } cudaEventDestroy(start); diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp index 7f701126..bd71e9cf 100644 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -33,6 +33,9 @@ #include + +#include "cutlass/util/device_memory.h" +#include "cutlass/integer_subbyte.h" #include "cutlass/float8.h" #include "cutlass/util/reference/device/tensor_fill.h" @@ -197,7 +200,6 @@ bool initialize_packed_scale( { cutlass::packed_scale_t tmp(data_in[i]); data_out[i] = reinterpret_cast const&>(tmp); - // std::cout << data_in[i] << ":" << std::hex << static_cast(data_in[i].storage) << ",\t" << -data_in[i] << ":" << std::hex << static_cast((-data_in[i]).storage) << std::endl; } try { block_out.copy_from_host(data_out.data()); @@ -207,4 +209,4 @@ bool initialize_packed_scale( return false; } return true; -} \ No newline at end of file +} diff --git a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp index 97df3b87..de5a3d3f 100644 --- a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp @@ -159,4 +159,4 @@ void reorder_tensor( cutlass::DeviceAllocation temp(size(layout_src)); reorder_tensor(data, layout_src, temp.get(), layout_dst); cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); -} \ No newline at end of file +} diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index d57e1dee..7b20a335 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -63,7 +63,7 @@ #include #include #include -#include +#include #include "cutlass/cutlass.h" diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp index 57365a8b..bfb64820 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -35,9 +35,36 @@ #include "dispatch_policy_extra.hpp" #include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" +#include "../pipeline/prefetch_pipeline_sm90.hpp" namespace cutlass::gemm::collective { +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size + constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); + constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes; +} + +} // namespace detail + // GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch template < class ElementA, @@ -98,7 +125,7 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; @@ -184,7 +211,7 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp index 710224d7..9bcb1f5a 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -57,6 +57,19 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +constexpr int PrefetchStages = 4; +constexpr int PrefetchInitialStages = 1; +// This determines how much shmem we set aside for prefetch. +// We don't reuse anything loaded by prefetcher, so we can keep +// loading into the same place -- there will be a conflict when +// writing, but it doesn't affect performance as much as the doors +// that this opens. +constexpr int PrefetchStagesActual = 1; + +} // namespace detail + // WarpSpecialized Mainloop template < int Stages, @@ -117,15 +130,7 @@ struct CollectiveMma< static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); - static constexpr int PrefetchStages = 4; - static constexpr int PrefetchInitialStages = 1; - // This determines how much shmem we set aside for prefetch. - // We don't reuse anything loaded by prefetcher, so we can keep - // loading into the same place -- there will be a conflict when - // writing, but it doesn't affect performance as much as the doors - // that this opens. - static constexpr int PrefetchStagesActual = 1; - using PrefetcherPipeline = cutlass::PrefetchPipeline; + using PrefetcherPipeline = cutlass::PrefetchPipeline; using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; @@ -155,7 +160,7 @@ struct CollectiveMma< using PrefetchSmemLayoutA = decltype(make_layout(make_shape( cute::Int(SmemLayoutA{})>{}, cute::Int(SmemLayoutA{})>{}, - cute::Int{}))); + cute::Int{}))); static constexpr auto prefetch_smem_size = cute::cosize_v; @@ -176,7 +181,7 @@ struct CollectiveMma< using InternalElementB = cute::conditional_t>>; // Defined outside the class where it's used, to work around MSVC issues - using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { @@ -660,7 +665,7 @@ struct CollectiveMma< bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); - prefetch_iters = min(k_tile_count, ((prefetch_iters + PrefetchStages - 1) / PrefetchStages) * PrefetchStages); + prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages); Tensor sA = make_tensor( make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) @@ -702,7 +707,7 @@ struct CollectiveMma< break; } - prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= PrefetchStages); + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages); using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); diff --git a/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt new file mode 100644 index 00000000..18320259 --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 64_ada_fp8_gemm_grouped + ada_fp8_gemm_grouped.cu + ) diff --git a/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu new file mode 100644 index 00000000..8e3dbbb0 --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu @@ -0,0 +1,1208 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Ada FP8 GEMM Grouped With Per-Group Scale Example. + + This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices + in Global Memory are passed to the kernel in array (also held in Global Memory). Similarly, + leading dimensions and problem sizes are stored in arrays in GMEM. + + This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM + concept may be distinct. + + The differences between this and the examples/24_gemm_grouped are: (1) this example scales the output of each GEMM by a different scalar value specified by alpha_ptr_array. (2) this example uses FP8 tensorcore. + + This benchmark program initializes a workspace with random problem sizes for a given number of + groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to + model problems more similar to the traditional batched GEMM. + + Additionally, problem sizes are collected and binned to compute the same problem as a series of + conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance + enhancement achieved by implementing a specialized grouped GEMM kernel. + + Examples: + + # Runs a grouped GEMM with 100 random problem sizes + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 + + # Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024) + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true + + # Runs a grouped GEMM that is equivalent to a batched GEMM + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true + + # Execute Grouped GEMM and profile with NSight + $ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true \ + --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double initialization_time_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double initialization_time_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for cutlass::gemm::GemmCoord +struct HashGemmCoord { + size_t operator()(cutlass::gemm::GemmCoord const &problem) const { + std::hash hasher; + return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool profile_initialization; + bool sort_problems; + + std::vector problem_sizes; + + // problem size bins + std::unordered_map< + cutlass::gemm::GemmCoord, + std::vector, + HashGemmCoord> problem_bins; + + int alignment; + int problem_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + std::vector alpha_array; + float beta; + std::string benchmark_path; + + std::string output_tag; + std::ofstream output_file; + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(16), + reference_check(true), + profile_initialization(false), + sort_problems(false), + problem_count(15), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 16); + cmd.get_cmd_line_argument("groups", problem_count, 15); + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + + std::string output_path; + cmd.get_cmd_line_argument("tag", output_tag); + cmd.get_cmd_line_argument("output_file", output_path); + + if (!output_path.empty()) { + + std::ios_base::openmode open_mode = std::ios_base::out; + + std::ifstream input_file(output_path.c_str()); + + if (input_file.good()) { + open_mode = std::ios_base::app; + input_file.close(); + } + + output_file.open(output_path.c_str(), open_mode); + + if (output_file.good() && open_mode != std::ios_base::app) { + output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; + } + } + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + error = true; + problem_sizes.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + // Post-process the problem sizes + bin_problems(); + + // Initalize alpha array + randomize_alpha_ptr_array(cmd); + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. + // + + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + + if (m < 1) { + m = alignment * ((rand() % 256) + 1); + } + + if (n < 1) { + n = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + cutlass::gemm::GemmCoord problem(m, n, k); + + problem_sizes.push_back(problem); + } + } + + void randomize_alpha_ptr_array(cutlass::CommandLine &cmd) { + alpha_array.resize(problem_count); + for (int i = 0; i < problem_count; ++i) { + alpha_array[i] = static_cast((rand() % 100) - 50 + alpha); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes.push_back(extent); + } + } + + return true; + } + + /// Post processes the problems + void bin_problems() { + + problem_bins.clear(); + + problem_count = int(problem_sizes.size()); + + // + // Insert the problem sizes into a sorted container class. This is *NOT* necessary + // to run the CUTLASS kernel, but it enables the execution of cublas's batched GEMM. + // + for (int i = 0; i < int(problem_sizes.size()); ++i) { + auto it = problem_bins.find(problem_sizes.at(i)); + if (it == problem_bins.end()) { + problem_bins.insert({problem_sizes.at(i), std::vector({i}) }); + } + else { + it->second.push_back(i); + } + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "64_ada_fp8_gemm_grouped\n\n" + << " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n" + << " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n" + << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" + << " in device Global Memory and loaded by the kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual GEMM problems (default: --groups=15)\n" + << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" + + << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" + << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" + << "#\n" + << "# For example, assume the following are the contents of 'problems.txt'\n" + << "#\n" + << "# 0 1024x256x520\n" + << "# 1 520x264x1024\n" + << "# 2 96x48x1024\n" + << "#\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --benchmark=problems.txt\n\n" + + << "# Execute Grouped GEMM and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes) { + fmas += problem.product(); + } + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BaseTestbed { +public: + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + std::vector alpha_ptr_array_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation alpha_array_device; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation alpha_ptr_array_device; + + BaseTestbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + int problem_count() const { + return options.problem_count; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = static_cast(5); + scope_min = static_cast(-5); + } + else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Allocates device-side data + void allocate() { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + lda_host.resize(problem_count()); + ldb_host.resize(problem_count()); + ldc_host.resize(problem_count()); + ldd_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes.at(i); + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + lda.reset(problem_count()); + ldb.reset(problem_count()); + ldc.reset(problem_count()); + ldd.reset(problem_count()); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + alpha_ptr_array_host.resize(problem_count()); + alpha_array_device.reset(problem_count()); + alpha_ptr_array_device.reset(problem_count()); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(problem_count()); + std::vector ptr_B_host(problem_count()); + std::vector ptr_C_host(problem_count()); + std::vector ptr_D_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count()); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count()); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count()); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count()); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), block_D.size(), ElementC(), ElementC()); + + // Initialize alpha array + alpha_array_device.copy_from_host(options.alpha_array.data()); + for (int32_t i = 0; i < problem_count(); ++i) { + alpha_ptr_array_host.at(i) = alpha_array_device.get() + i; + } + alpha_ptr_array_device.copy_from_host(alpha_ptr_array_host.data()); + } + + /// Verifies the result is a GEMM + bool verify() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); + + cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + options.alpha_array[i], + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + options.beta, + view_C, + view_Ref_device, + ElementAccumulator(0) + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); + + cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + return passed; + } + } + + return passed; + } + +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine GEMM with different GroupScheduleMode_ + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + typename Gemm_::ElementA, + typename Gemm_::LayoutA, + Gemm_::kTransformA, + Gemm_::kAlignmentA, + typename Gemm_::ElementB, + typename Gemm_::LayoutB, + Gemm_::kTransformB, + Gemm_::kAlignmentB, + typename Gemm_::ElementC, + typename Gemm_::LayoutC, + typename Gemm_::ElementAccumulator, + typename Gemm_::OperatorClass, + typename Gemm_::ArchTag, + typename Gemm_::ThreadblockShape, + typename Gemm_::WarpShape, + typename Gemm_::InstructionShape, + typename Gemm_::EpilogueOutputOp, + typename Gemm_::ThreadblockSwizzle, + Gemm_::kStages, + GroupScheduleMode_>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + std::cout << std::endl; + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Gemm::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Gemm::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime + Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + + std::cout << std::endl; + std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the GEMM arguments + typename Gemm::EpilogueOutputOp::ElementCompute ** alpha_ptr_array = this->alpha_ptr_array_device.get(); + typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha_ptr_array, nullptr); + + // Configure GEMM arguments + typename Gemm::Arguments args( + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = gemm.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Run the grouped GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + gemm.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + gemm.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Gemm::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { + std::cerr << "This example requires CUDA 12.4 or greater." << std::endl; + return 0; + } + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl; + return 0; + } + + if (!(properties.major == 8 && properties.minor == 9)) { + std::cerr << "CUTLASS's Ada FP8 Gemm Grouped example requires a device of compute capability 89.\n" << std::endl; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Define the Grouped and Batched GEMM types + // + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + constexpr int ElementsPerAccessB = 128 / cutlass::sizeof_bits::value; + + // Define a grouped GEMM kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + ElementsPerAccessA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + ElementsPerAccessB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 4>::GemmKernel; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + // + // Profile it + // + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 6486d714..7e8d4522 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -143,8 +143,10 @@ foreach(EXAMPLE 61_hopper_gemm_with_topk_and_softmax 62_hopper_sparse_gemm 63_hopper_gemm_with_weight_prefetch + 64_ada_fp8_gemm_grouped ) add_subdirectory(${EXAMPLE}) endforeach() + diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu index 87ad873c..a8ae3b10 100644 --- a/examples/cute/tutorial/tiled_copy.cu +++ b/examples/cute/tutorial/tiled_copy.cu @@ -95,36 +95,17 @@ __global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) /// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation /// has the precondition that pointers are aligned to the vector size. /// -template -__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout) +template +__global__ void copy_kernel_vectorized(TensorS S, TensorD D, Tiled_Copy tiled_copy) { using namespace cute; - using Element = typename TensorS::value_type; // Slice the tensors to obtain a view into each tile. Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) - // Define `AccessType` which controls the size of the actual memory access. - using AccessType = cutlass::AlignedArray; - - // A copy atom corresponds to one hardware memory access. - using Atom = Copy_Atom, Element>; - - // Construct tiled copy, a tiling of copy atoms. - // - // Note, this assumes the vector and thread layouts are aligned with contigous data - // in GMEM. Alternative thread layouts are possible but may result in uncoalesced - // reads. Alternative vector layouts are also possible, though incompatible layouts - // will result in compile time errors. - auto tiled_copy = - make_tiled_copy( - Atom{}, // access size - ThreadLayout{}, // thread layout - VecLayout{}); // vector layout (e.g. 4x1) - // Construct a Tensor corresponding to each thread's slice. - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x); Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN) Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN) @@ -198,11 +179,34 @@ int main(int argc, char** argv) Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n') Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n') - // Thread arrangement - Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); + // Construct a TiledCopy with a specific access pattern. + // This version uses a + // (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc), + // (2) Layout-of-Values that each thread will access. - // Vector dimensions - Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); + // Thread arrangement + Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx + + // Value arrangement per thread + Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx + + // Define `AccessType` which controls the size of the actual memory access instruction. + using CopyOp = UniversalCopy>; // A very specific access width copy instruction + //using CopyOp = UniversalCopy>; // A more generic type that supports many copy strategies + //using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs + + // A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element. + using Atom = Copy_Atom; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous data + // in GMEM. Alternative thread layouts are possible but may result in uncoalesced + // reads. Alternative value layouts are also possible, though incompatible layouts + // will result in compile time errors. + TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy + thr_layout, // thread layout (e.g. 32x4 Col-Major) + val_layout); // value layout (e.g. 4x1) // // Determine grid and block dimensions @@ -217,8 +221,7 @@ int main(int argc, char** argv) copy_kernel_vectorized<<< gridDim, blockDim >>>( tiled_tensor_S, tiled_tensor_D, - thr_layout, - vec_layout); + tiled_copy); cudaError result = cudaDeviceSynchronize(); if (result != cudaSuccess) { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index 9d080116..c9e02245 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -51,19 +51,14 @@ naive_cooperative_copy(uint32_t const& tid, Tensor const& src, Tensor & dst) { - auto N = size(src); - if (tid < N) { - uint32_t upper_bound = (N / NumThreads) * NumThreads; - CUTE_UNROLL - for (uint32_t i = 0; i < upper_bound; i += NumThreads) { // All in-bounds - dst[tid + i] = src[tid + i]; - } - if (N % NumThreads != 0) { // Likely static condition - uint32_t final_idx = tid + upper_bound; - if (final_idx < N) { // Final in-bounds - dst[final_idx] = src[final_idx]; - } - } + auto N = size(dst); + auto R = N % Int{}; + if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds + dst[tid] = src[tid]; + } + CUTE_UNROLL + for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; } } @@ -117,12 +112,14 @@ heuristic_permutation(Tensor const& a, // template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor & dst) + Tensor & dst, + CopyPolicy const& cpy = {}) { // Assumes the shapes are static, can generalize/fallback CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); @@ -283,23 +280,28 @@ cooperative_copy(uint32_t const& tid, // If we're using all threads (static) or the tid is in-range (dynamic) if (vec_thrs == NumThreads or tid < vec_thrs) { - return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + auto src_c = recast(src_v); + auto dst_c = recast(dst_v); + return copy(cpy, src_c, dst_c); } } } + // Default max-vectorization size to value_type size template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor & dst) + Tensor & dst, + CopyPolicy const& cpy = {}) { constexpr uint32_t MaxVecBits = sizeof_bits_v; - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } // @@ -308,26 +310,30 @@ cooperative_copy(uint32_t const& tid, template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor && dst) + Tensor && dst, + CopyPolicy const& cpy = {}) { - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } template + class DstEngine, class DstLayout, + class CopyPolicy = DefaultCopy> CUTE_HOST_DEVICE void cooperative_copy(uint32_t const& tid, Tensor const& src, - Tensor && dst) + Tensor && dst, + CopyPolicy const& cpy = {}) { - return cooperative_copy(tid, src, dst); + return cooperative_copy(tid, src, dst, cpy); } } // end namespace cute diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index 2c91ce6f..e4bd5ea6 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -50,31 +50,115 @@ namespace cute namespace detail { -// Predicated Cooperative GEMM -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> +// Slow fallback path: +template CUTE_HOST_DEVICE void -cooperative_gemm_predication(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +epilogue_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C { - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename ThrMMA::ValTypeC; + CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v); + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i) + : alpha * tCrC(i) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +template +CUTE_HOST_DEVICE +void +epilogue_no_predication(Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C + SmemCopyOpC const& sC_copy_op) +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename TRC::value_type; + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + Tensor tCrDi = make_fragment_like(tCsC); + Tensor tCrD = make_fragment_like(tCrC); + if(!isBetaZero) { + copy(sC_copy_op, tCsC, tCrDi); + // Transform C on/after load + cute::transform(tCrDi, tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, tCrDi, sC_store_op); + copy(sC_copy_op, tCrDi, tCsC); +} + +// Predicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_predication(ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; // // MMA Partitioning @@ -83,22 +167,18 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // Partition the sA, sB, and sC tiles across the threads for the MMA Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) - Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) #if 0 if (thread0()) { print(" sA: "); print( sA); print("\n"); print(" sB: "); print( sB); print("\n"); - print(" sC: "); print( sC); print("\n"); print(thr_mma); print("tCsA: "); print(tCsA); print("\n"); print("tCsB: "); print(tCsB); print("\n"); - print("tCsC: "); print(tCsC); print("\n"); print("tCrA: "); print(tCrA); print("\n"); print("tCrB: "); print(tCrB); print("\n"); print("tCrC: "); print(tCrC); print("\n"); @@ -154,23 +234,20 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M CUTE_UNROLL for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I - tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{}; } } CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N CUTE_UNROLL for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I - tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{}; } } // // MAINLOOP // - // Clear accumulators - clear(tCrC); - CUTE_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { @@ -185,138 +262,80 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M CUTE_UNROLL for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I - tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{}; } } CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N CUTE_UNROLL for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I - tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{}; } } } // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } - - // - // Epilogue - // - - // Create coordinate tensors for the problem - Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) - // Repeat partitioning with thr_mma - Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) - - const bool isBetaZero = (beta == Beta{}); - - // Custom axpby_if for now - CUTE_UNROLL - for (int i = 0; i < size(tCrC); ++i) - { - if (elem_less(tCcC(i), shape(sC))) - { - tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast(tCrC(i)) - : alpha * static_cast(tCrC(i)) + - beta * static_cast(sC_load_op(tCsC(i)))); - } - } -} - -// Slow fallback path -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> -CUTE_HOST_DEVICE -void -cooperative_gemm_predication(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C -{ - // ThrMMA - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op); } // Unpredicated Cooperative GEMM -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class SmemCopyOpA, class SmemCopyOpB> CUTE_HOST_DEVICE void -cooperative_gemm_no_predication(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +cooperative_gemm_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op, + SmemCopyOpB const& sB_copy_op) { - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; - // ThrMMA - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); // // MMA Partitioning // - Tensor tCsC = thr_mma.partition_C(sC); // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) using CopyOpAType = SmemCopyOpA; using CopyOpBType = SmemCopyOpB; - auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); Tensor tCsA = smem_thr_copy_A.partition_S(sA); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + Tensor tCrAi = make_fragment_like(tCsA); + Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K - auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); Tensor tCsB = smem_thr_copy_B.partition_S(sB); - Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + Tensor tCrBi = make_fragment_like(tCsB); + Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K #if 0 if (thread0()) { print(" sA: "); print(sA); print("\n"); print(" sB: "); print(sB); print("\n"); - print(" sC: "); print(sC); print("\n"); print(thr_mma); print("\n"); - print("tCsC: "); print(tCsC); print("\n"); print("tCrA: "); print(tCrA); print("\n"); print("tCrB: "); print(tCrB); print("\n"); print("tCrC: "); print(tCrC); print("\n"); @@ -333,15 +352,12 @@ cooperative_gemm_no_predication(uint32_t thread_idx, // PREFETCH // - copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); - copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{})); // // MAINLOOP // - // Clear accumulators - clear(tCrC); - constexpr int K_BLOCK_MAX = size<2>(tCrA); CUTE_UNROLL @@ -352,132 +368,178 @@ cooperative_gemm_no_predication(uint32_t thread_idx, { // Load the next k_block int k_next = k_block + 1; // statically unrolled - copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next)); - copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next)); + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next)); } // Transform A and B, relying on the compiler to remove in case of identity ops - cute::transform(tCrA(_,_,k_block), sA_load_op); - cute::transform(tCrB(_,_,k_block), sB_load_op); + cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op); // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } - - // - // Epilogue - // - - auto isBetaZero = [&] () { - if constexpr (is_complex::value) { - return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; - } - else { - return beta == Int<0>{}; - } - CUTE_GCC_UNREACHABLE; - } (); - - using CopyOpCType = SmemCopyOpC; - Tensor tCrD = thr_mma.make_fragment_C(tCsC); - if(!isBetaZero) { - copy(CopyOpCType{}, tCsC, tCrD); - // Transform C on/after load - cute::transform(tCrD, sC_load_op); - } - // C = alpha * (A * B) + beta * C - axpby(alpha, tCrC, beta, tCrD); - // Transform C before/on store - cute::transform(tCrD, sC_store_op); - copy(CopyOpCType{}, tCrD, tCsC); } } // end namespace detail -template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> -CUTE_HOST_DEVICE -void -cooperative_gemm(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C -{ - CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM - CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN - CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK - - using TypeA = typename TA::value_type; - using TypeB = typename TB::value_type; - using TypeC = typename TC::value_type; - - static_assert(is_convertible_v>, TypeA>, - "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type"); - static_assert(is_convertible_v>, TypeB>, - "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type"); - static_assert(is_convertible_v>, TypeC>, - "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - static_assert(is_convertible_v>, TypeC>, - "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); - - static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), - tile_shape(TiledMMA{})); - if constexpr (compat) { - detail::cooperative_gemm_no_predication( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op - ); - } else { - detail::cooperative_gemm_predication( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op - ); - } -} - +// C passed as a shared memory tensor +// Epilogue included template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class SmemCopyOpA = DefaultCopy, class SmemCopyOpB = DefaultCopy, + class SmemCopyOpC = DefaultCopy> +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC + + // Clear accumulators + clear(tCrC); + +#if 0 + if (thread0()) { + print(" sC: "); print(sC); print("\n"); + print(" tCsC: "); print(tCsC); print("\n"); + } +#endif + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + detail::epilogue_no_predication( + alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); + } +} + +// C already partitioned into registers on input +// It can be passed non-empty +// Epilogue not included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + // Check if input C fragment is compatible with thr_mma and problem size + using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB)))); + CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC))); + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + } +} + +// Accept mutable temporaries +template CUTE_HOST_DEVICE void cooperative_gemm(uint32_t thread_idx, - TiledMMA const& tiled_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor && sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) { - using CopyOpA = AutoVectorizingCopyWithAssumedAlignment>; - using CopyOpB = AutoVectorizingCopyWithAssumedAlignment>; - using CopyOpC = AutoVectorizingCopyWithAssumedAlignment>; - cooperative_gemm( - thread_idx, tiled_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op - ); + cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op, + sA_copy_op, sB_copy_op, sC_copy_op); } // Legacy overload of cute::gemm for backwards-compatibility @@ -485,27 +547,38 @@ template ::value && - BLayout::rank == 2 && is_smem::value && - CLayout::rank == 2 && is_smem::value)> + class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity> CUTE_HOST_DEVICE void -gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM - BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM - CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM - CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + // Goes directly to the slow path to avoid getting thread_idx from thr_mma detail::cooperative_gemm_predication( - thr_mma, alpha, sA, sB, beta, sC, - sA_load_op, sB_load_op, sC_load_op, sC_store_op + thr_mma, sA, sB, sC, sA_load_op, sB_load_op + ); + + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op ); } diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index c2decd15..84ef4916 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -38,79 +38,6 @@ namespace cute { -// -// Accept mutable temporaries -// - -template -CUTE_HOST_DEVICE -void -copy(Tensor const& src, - Tensor && dst) -{ - return copy(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_vec(Tensor const& src, - Tensor && dst) -{ - return copy_vec(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_aligned(Tensor const& src, - Tensor && dst) -{ - return copy_aligned(src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_if(PrdTensor const& pred, - Tensor const& src, - Tensor && dst) -{ - return copy_if(pred, src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy_if(CopyPolicy const& copy_policy, - PrdTensor const& pred, - Tensor const& src, - Tensor && dst) -{ - return copy_if(copy_policy, pred, src, dst); -} - -template -CUTE_HOST_DEVICE -void -copy(CopyPolicy const& copy_policy, - Tensor const& src, - Tensor && dst) -{ - return copy(copy_policy, src, dst); -} - // // copy_if -- Predicated Copy // @@ -124,12 +51,13 @@ copy_if(PrdTensor const& pred, Tensor const& src, Tensor & dst) { - auto copy_op = select_elementwise_copy(src, dst); + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; CUTE_UNROLL - for (int i = 0; i < size(src); ++i) { + for (int i = 0; i < size(dst); ++i) { if (pred(i)) { - copy_op.copy(src(i), dst(i)); + dst(i) = static_cast(static_cast(src(i))); } } } @@ -138,17 +66,6 @@ copy_if(PrdTensor const& pred, // copy_if -- Predicated CopyAtom // -namespace detail { - -// Trait that detects if atom's traits has a member function with(bool) -template -constexpr bool has_with_bool = false; - -template -constexpr bool has_with_bool().with(declval()))>> = true; - -} // end namespace detail - template const& copy_atom, Tensor & dst) // (V,Rest...) { static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy - copy_atom.call(src, dst); + if constexpr (has_with_bool) { + copy_atom.with(pred()).call(src, dst); + } else { + if (pred()) { copy_atom.call(src, dst); } + } } else { // Loop over all but the first mode constexpr int R = SrcLayout::rank; Tensor src_v = group_modes<1,R>(src); Tensor dst_v = group_modes<1,R>(dst); CUTE_UNROLL - for (int i = 0; i < size<1>(src_v); ++i) { - // If copy traits can be transformed with a predicate value, do it, otherwise branch here - if constexpr (detail::has_with_bool>) { + for (int i = 0; i < size<1>(dst_v); ++i) { + if constexpr (has_with_bool) { copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); } else { - if (pred(i)) { - copy_atom.call(src_v(_,i), dst_v(_,i)); - } + if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); } } } } } // -// copy_vec -- attempt vectorized copy with VecType +// copy_if -- AutoCopyAsync // - -template CUTE_HOST_DEVICE void -copy_vec(Tensor const& src, - Tensor & dst) +copy_if(AutoCopyAsync const& cpy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) { - static_assert(sizeof_bits_v >= 8 && sizeof_bits_v % 8 == 0, - "Expected a vectorization type of at least a byte."); + using SrcElemWithConst = remove_reference_t; using SrcType = typename SrcEngine::value_type; using DstType = typename DstEngine::value_type; - if constexpr (cute::is_same::value && - sizeof_bits_v > sizeof_bits_v) - { - // Preserve volatility of Src/Dst types. - using SrcVecType = conditional_t, VecType const volatile, VecType const>; - using DstVecType = conditional_t, VecType volatile, VecType >; - Tensor src_v = recast(src); - Tensor dst_v = recast(dst); -#if 0 - if (thread0()) { - print("copy_vec<%db> -- vectorizing copy:\n", int(sizeof_bits_v)); - print(" "); print(src); print(" => "); print(src_v); print("\n"); - print(" "); print(dst); print(" => "); print(dst_v); print("\n"); + auto copy_op = []() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType)) { + if constexpr (is_const_v && sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEGLOBAL{}; + } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + } else { + return UniversalCopy{}; } -#endif - return copy_if(TrivialPredTensor{}, src_v, dst_v); - } else { -#if 0 - if (thread0()) { - print("copy_vec<%db> -- NOT vectorizing copy:\n", int(sizeof_bits_v)); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); - } + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; #endif + }(); - return copy_if(TrivialPredTensor{}, src, dst); + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } } } +// +// copy -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy(AutoCopyAsync const& cpy, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + copy_if(cpy, TrivialPredTensor{}, src, dst); +} + // // copy -- CopyAtom // @@ -238,15 +172,56 @@ template const& copy_atom, - Tensor const& src, - Tensor & dst) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) { - return copy_if(copy_atom, TrivialPredTensor{}, src, dst); + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + + if constexpr (is_static::value && is_static::value) { + CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v)); + + // AutoFilter on the Rest-mode + auto dst_null = nullspace(layout<1>(dst_v)); + + Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest)) + Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest)) + + CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n)); + CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error"); + CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{})); + CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{})); + + Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + + CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); + CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); + + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_c); ++i) { + copy_atom.call(src_c(_,i), dst_c(_,i)); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } } -////////////////////////////////////////// -// Special Auto-Vectorizing Overloads -////////////////////////////////////////// +//////////////////////////////////////////////////////// +// Special Auto-Vectorizing, Auto-Filtering Overloads // +//////////////////////////////////////////////////////// // Specialization for AutoVectorizingCopyAssumedAlignment template const&, Tensor const& src, Tensor & dst) { - constexpr int vec_elem = decltype(max_common_vector(src, dst))::value; + constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit copy!"); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); - constexpr int max_align_src = decltype(max_alignment(src.layout()))::value; - constexpr int max_align_dst = decltype(max_alignment(dst.layout()))::value; - constexpr int max_align = gcd(vec_elem, max_align_src, max_align_dst); + if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) { + // If more than one element vectorizes to 8bits or more, then recast and copy + using VecType = uint_bit_t; + // Preserve volatility + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; - constexpr int src_bits = sizeof_bits::value; - constexpr int vec_bits = gcd(src_bits * max_align, MaxVecBits); + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); - if constexpr (vec_elem > 1 && vec_bits >= 8) { - // If more than one element vectorizes to 8bits or more, then copy_vec #if 0 if (thread0()) { - print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", vec_elem, vec_bits); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); + print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits); + print(" "); print(src); print(" => "); print(src_v); print("\n"); + print(" "); print(dst); print(" => "); print(dst_v); print("\n"); } #endif - return copy_vec>(src, dst); + + return copy_if(TrivialPredTensor{}, src_v, dst_v); } else { return copy_if(TrivialPredTensor{}, src, dst); } } +template +struct AutoFilter { + Base const& base; + CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {} +}; + +// Specialization for AutoFilter +template +CUTE_HOST_DEVICE +void +copy(AutoFilter const& copy_op, + Tensor const& src, + Tensor & dst) +{ + if constexpr (is_constant::value) { + auto dst_null = nullspace(dst.layout()); + + Tensor dst_n = zipped_divide(dst, dst_null); + Tensor src_n = zipped_divide(src, dst_null); + + CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy"); + + copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); + } else { + copy(copy_op.base, src, dst); + } +} + // Auto-vectorizing copy for static layouts template @@ -292,7 +304,11 @@ copy(Tensor const& src, { if constexpr (is_static::value && is_static::value) { // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned - return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned. + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst); } else { // Do not assume that dynamic layouts are aligned. return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); @@ -307,7 +323,12 @@ void copy_aligned(Tensor const& src, Tensor & dst) { - return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else { + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } } // Specializaton for Atom AutoVectorizingCopyAssumedAlignment @@ -379,4 +400,146 @@ copy(Copy_Atom, CA_Args...> const& } #endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +// +// Decay TiledCopy to CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(TiledCopy const& tiled_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + return copy_if(static_cast(tiled_copy), pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(TiledCopy const& tiled_copy, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast(tiled_copy), src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(ThrCopy const& thr_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) = delete; + +template +CUTE_HOST_DEVICE +void +copy(ThrCopy const& thr_copy, + Tensor const& src, + Tensor & dst) = delete; + +// +// Catch uncaught policies +// + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& cpy, + PredTensor const& prd, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& cpy, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + } // end namespace cute diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp index 51392899..47dbef2f 100644 --- a/include/cute/arch/copy.hpp +++ b/include/cute/arch/copy.hpp @@ -39,7 +39,7 @@ namespace cute { // -// Direct Copy for any type +// Direct Copy for any specific types // template @@ -48,21 +48,15 @@ struct UniversalCopy using SRegisters = S[1]; using DRegisters = D[1]; - template - CUTE_HOST_DEVICE static constexpr void - copy(S_ const& src, - D_ & dst) - { - dst = static_cast(static_cast(src)); - } + // Sanity + static_assert(sizeof_bits_v >= 8); + static_assert(sizeof_bits_v >= 8); - // Accept mutable temporaries - template CUTE_HOST_DEVICE static constexpr void - copy(S_ const& src, - D_ && dst) + copy(S const& src, + D & dst) { - UniversalCopy::copy(src, dst); + dst = src; } }; @@ -92,6 +86,12 @@ using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; +// +// Copy policy automatically selecting between +// UniversalCopy and cp.async , based on type and memory space. +// +struct AutoCopyAsync {}; + // // Global memory prefetch into L2 // diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp index 60777f22..17860dd4 100644 --- a/include/cute/arch/mma_sm80.hpp +++ b/include/cute/arch/mma_sm80.hpp @@ -2040,6 +2040,103 @@ struct SM80_16x8x64_S32U4U4S32_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// // MMA 8x8x128 TN @@ -2141,103 +2238,4 @@ struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC //////////////////////////////////////////////////////////////////////////////////////////////////// -// MMA 8x8x128 TN -struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC -{ - using DRegisters = uint32_t[2]; - using ARegisters = uint32_t[1]; - using BRegisters = uint32_t[1]; - using CRegisters = uint32_t[2]; - - CUTE_HOST_DEVICE static void - fma(uint32_t & d0, uint32_t & d1, - uint32_t const& a0, - uint32_t const& b0, - uint32_t const& c0, uint32_t const& c1) - { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) - asm volatile( - "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " - "{%0, %1}," - "{%2}," - "{%3}," - "{%4, %5};\n" - : "=r"(d0), "=r"(d1) - : "r"(a0), - "r"(b0), - "r"(c0), "r"(c1)); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// MMA 16x8x128 TN -struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC -{ - using DRegisters = uint32_t[4]; - using ARegisters = uint32_t[2]; - using BRegisters = uint32_t[1]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& a0, uint32_t const& a1, - uint32_t const& b0, - uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) - { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) - asm volatile( - "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) - : "r"(a0), "r"(a1), - "r"(b0), - "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// MMA 16x8x256 TN -struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC -{ - using DRegisters = uint32_t[4]; - using ARegisters = uint32_t[4]; - using BRegisters = uint32_t[2]; - using CRegisters = uint32_t[4]; - - CUTE_HOST_DEVICE static void - fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, - uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, - uint32_t const& b0, uint32_t const& b1, - uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) - { -#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) - asm volatile( - "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), - "r"(b0), "r"(b1), - "r"(c0), "r"(c1), "r"(c2), "r"(c3)); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index dd6b4e52..75b7aa4d 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -100,16 +100,16 @@ struct Copy_Atom, CopyInternalType> if constexpr (is_constant::value || is_constant::value) { // Dispatch to unpack to execute instruction - return copy_unpack(*this, src, dst); - } else - if constexpr (is_tuple::value && - is_tuple::value) { + return copy_unpack(static_cast(*this), src, dst); + } else if constexpr (is_tuple::value && + is_tuple::value) { // If the size of the src/dst doesn't match the instruction, // recurse this rank-1 layout by peeling off the mode // ((A,B,C,...)) -> (A,B,C,...) return copy(*this, tensor<0>(src), tensor<0>(dst)); } else { - static_assert(dependent_false, "No instruction match and no recursion possible."); + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); } } diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index bfbeb4ea..ac746a64 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -92,23 +92,29 @@ struct Copy_Traits> using RefLayout = SrcLayout; }; +// Extract a CPY_Op from a CPY_Traits +template +struct CPY_Op {}; + +template +struct CPY_Op> { + using type = CPY_Op_Arg; +}; + // // Generic copy_unpack for common argument-based Copy_Traits // -template CUTE_HOST_DEVICE constexpr void -copy_unpack(Copy_Traits const&, - Tensor const& src, - Tensor & dst) +copy_unpack(AnyCPYTraits const&, + Tensor const& src, + Tensor & dst) { - // Specializations can generalize on these checks - //static_assert(is_smem::value, "Expected smem for this Copy_Traits"); - //static_assert(is_rmem::value, "Expected rmem for this Copy_Traits"); - + using CopyOp = typename CPY_Op::type; using RegistersSrc = typename CopyOp::SRegisters; using RegistersDst = typename CopyOp::DRegisters; using RegTypeSrc = typename remove_extent::type; @@ -129,18 +135,15 @@ copy_unpack(Copy_Traits const&, rD, make_int_sequence{}); } -// // Accept mutable temporaries -// - -template CUTE_HOST_DEVICE constexpr void -copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor && dst) +copy_unpack(AnyCPYTraits const& traits, + Tensor const& src, + Tensor && dst) { copy_unpack(traits, src, dst); } diff --git a/include/cute/atom/copy_traits_sm80.hpp b/include/cute/atom/copy_traits_sm80.hpp index e5ff0b7b..3795f52a 100644 --- a/include/cute/atom/copy_traits_sm80.hpp +++ b/include/cute/atom/copy_traits_sm80.hpp @@ -51,13 +51,6 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - - // Construct a zfill variant with a given predicate value - CUTE_HOST_DEVICE constexpr - Copy_Traits> - with(bool pred) const { - return {pred}; - } }; template @@ -73,13 +66,6 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - - // Construct a zfill variant with a given predicate value - CUTE_HOST_DEVICE constexpr - Copy_Traits> - with(bool pred) const { - return {pred}; - } }; template @@ -96,8 +82,15 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // Predicate value that determines whether to load or zfill - bool pred = false; + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } // Overload copy_unpack for zfill variant to pass the predicate into the op template > // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - // Predicate value that determines whether to load or zfill - bool pred = false; + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } // Overload copy_unpack for zfill variant to pass the predicate into the op template > } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Element copy selector -template -CUTE_HOST_DEVICE constexpr -auto -select_elementwise_copy(SrcTensor const&, DstTensor const&) -{ - using SrcType = typename SrcTensor::value_type; - using DstType = typename DstTensor::value_type; - -#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) - if constexpr (is_gmem::value && is_smem::value && - sizeof(SrcType) == sizeof(DstType) && - (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16)) - { - return SM80_CP_ASYNC_CACHEALWAYS{}; - } else { - return UniversalCopy{}; - } - - CUTE_GCC_UNREACHABLE; -#else - return UniversalCopy{}; -#endif -} - -} +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 3738cc39..4ad7f808 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -58,37 +58,31 @@ struct AuxTmaParams { }; // Utility for unpacking TMA_LOAD arguments into a CopyOp -template +template struct TMA_LOAD_Unpack { - template CUTE_HOST_DEVICE friend constexpr void copy_unpack(Copy_Traits const& traits, Tensor const& src, Tensor & dst) { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + auto src_coord = src.data().coord_; - if constexpr (detail::is_prefetch) { - return detail::explode_tuple(detail::CallCOPY{}, - traits.opargs_, tuple_seq{}, - src_coord, tuple_seq{}); - } else { - static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); - void* dst_ptr = cute::raw_pointer_cast(dst.data()); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); #if 0 - auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); - printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z, - int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); #endif - return detail::explode_tuple(detail::CallCOPY{}, - traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord, tuple_seq{}); - } + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); } }; @@ -131,7 +125,7 @@ struct Copy_Traits [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; } // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) @@ -143,7 +137,7 @@ struct Copy_Traits [[maybe_unused]] uint16_t const& multicast_mask = 0, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { // We accept multicast_mask here to keep the API for both atoms consistent - return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; } // Generate the TMA coord tensor @@ -167,7 +161,7 @@ struct Copy_Traits // The executable SM90_TMA_LOAD with tma_desc and tma_mbar template struct Copy_Traits - : TMA_LOAD_Unpack + : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -183,12 +177,15 @@ struct Copy_Traits uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} }; // The prefetch for SM90_TMA_LOAD with tma_desc template struct Copy_Traits - : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -206,6 +203,19 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const& traits) : opargs_({&traits.tma_desc_}) {} + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -246,7 +256,7 @@ struct Copy_Traits uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { - return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; } // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) @@ -257,7 +267,7 @@ struct Copy_Traits uint64_t& tma_load_mbar, uint16_t const& multicast_mask, TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { - return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; } // Generate the TMA coord tensor @@ -281,7 +291,7 @@ struct Copy_Traits // The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask template struct Copy_Traits - : TMA_LOAD_Unpack + : TMA_LOAD_Unpack { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -298,43 +308,17 @@ struct Copy_Traits uint16_t, // multicast mask uint64_t // cache hint > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) + : opargs_(desc, mbar, mask, hint) {} }; ////////////////////////////////////////////////////////////////////////////// ///////////////////////////// TMA_STORE ////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// -// Utility for unpacking TMA_STORE arguments into a CopyOp -template -struct TMA_STORE_Unpack -{ - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const& traits, - Tensor const& src, - Tensor & dst) - { - static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); - - void const* const desc_ptr = traits.tma_desc_; - void const* const src_ptr = cute::raw_pointer_cast(src.data()); - auto dst_coord = dst.data().coord_; -#if 0 - auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); - printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z, - int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); -#endif - return detail::explode_tuple(detail::CallCOPY{}, - make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); - } -}; - -struct SM90_TMA_STORE_OP : SM90_TMA_STORE {}; +struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {}; // The executable SM90_TMA_STORE with tma_desc template @@ -369,6 +353,13 @@ struct Copy_Traits return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); } + // Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + template CUTE_HOST_DEVICE friend constexpr void @@ -393,19 +384,11 @@ struct Copy_Traits make_tuple(desc_ptr, src_ptr), seq<0,1>{}, dst_coord, tuple_seq{}); } - - // Construct Copy_Traits executable (w/ swapped out TMA descriptor) for SM90_TMA_STORE (for grouped gemm/ptr array gemm) - CUTE_HOST_DEVICE constexpr - Copy_Traits - with(TmaDescriptor const* new_tma_desc) const { - return {{}, new_tma_desc}; - } }; -// The executable SM90_TMA_STORE with tma_desc +// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead template -struct Copy_Traits - : TMA_STORE_Unpack +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit @@ -417,6 +400,31 @@ struct Copy_Traits // SM90_TMA_STORE arguments TmaDescriptor const* tma_desc_; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -520,7 +528,7 @@ struct Copy_Traits CUTE_HOST_DEVICE constexpr Copy_Traits with(uint64_t& bulk_mbar) const { - return {{&bulk_mbar}}; + return {&bulk_mbar}; } template CUTE_HOST_DEVICE constexpr Copy_Traits with(uint64_t& bulk_mbar) const { - return {{&bulk_mbar}}; + return {&bulk_mbar}; } }; @@ -1391,19 +1399,46 @@ tma_partition(Copy_Atom const& copy_atom, return cute::make_tuple(gresult, sresult); } +// Explicit defaults for cta_coord and cta_layout +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor); +} + // TMA Multicast Masks Calculation template CUTE_HOST_DEVICE constexpr -auto +uint16_t create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, CtaCoord const& cta_coord_vmnk) { auto cta_coord_slicer = replace(cta_coord_vmnk, _); auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); - // Get the instruction code + uint16_t mcast_mask = 0; - for (int i = 0; i < size(cta_layout); ++i) { - mcast_mask |= uint16_t(1) << cta_layout(i); + if constexpr (rank_v == 1 and depth_v <= 1 and + not is_static::value) { + // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout + mcast_mask = uint16_t(1); + // Smear by stride<0> (may want to predicate on stride<0> mag?) + mcast_mask |= mcast_mask << (1*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (2*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (4*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (8*stride<0>(cta_layout)); + // Select shape<0> + mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout))); + } else { + // Get the instruction code -- generic path + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); + } } // Shift by the instruction's elected block rank (dynamic) mcast_mask <<= elected_cta; diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index bf408274..7cb4fe3d 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -250,12 +250,12 @@ struct TiledMMA : MMA_Atom auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) // Tile the tensor for the Atom - auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), make_layout(size<1>(AtomShape_MNK{}))); - auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomN),(RestM,RestN)) + auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) // Transform the Atom mode from (M,K) to (Thr,Val) - auto tv_tensor = a_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) // Tile the tensor for the C-threads auto thr_tile = make_tile(_, @@ -604,16 +604,15 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutC_TV{}); - auto M = shape_div(size<0>(shape_MN), size<0>(atomMNK) * size<1>(thrVMNK)); - auto N = shape_div(size<1>(shape_MN), size<1>(atomMNK) * size<2>(thrVMNK)); - return cute::tuple_cat(make_shape(V,M,N), take<2,R>(shape_MN)); + auto dummy = make_layout(shape(shape_MN)); + auto dummy_tv = mma.thrfrg_C(dummy); + // Slice+rearrange like partition_C + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } + template CUTE_HOST_DEVICE constexpr auto @@ -632,14 +631,12 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutA_TV{}); - auto M = shape_div(size<0>(shape_MK), size<0>(atomMNK) * size<1>(thrVMNK)); - auto K = shape_div(size<1>(shape_MK), size<2>(atomMNK) * size<3>(thrVMNK)); - return cute::tuple_cat(make_shape(V,M,K), take<2,R>(shape_MK)); + auto dummy = make_layout(shape(shape_MK)); + auto dummy_tv = mma.thrfrg_A(dummy); + // Slice+rearrange like partition_A + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } template @@ -647,14 +644,12 @@ CUTE_HOST_DEVICE constexpr auto partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) { - constexpr int R = rank_v; - static_assert(R >= 2, "Must have at least rank-2"); - auto atomMNK = typename TiledMMA::AtomShape_MNK{}; - auto thrVMNK = typename TiledMMA::ThrLayoutVMNK{}; - auto V = shape<1>(typename TiledMMA::AtomLayoutB_TV{}); - auto N = shape_div(size<0>(shape_NK), size<1>(atomMNK) * size<2>(thrVMNK)); - auto K = shape_div(size<1>(shape_NK), size<2>(atomMNK) * size<3>(thrVMNK)); - return cute::tuple_cat(make_shape(V,N,K), take<2,R>(shape_NK)); + auto dummy = make_layout(shape(shape_NK)); + auto dummy_tv = mma.thrfrg_B(dummy); + // Slice+rearrange like partition_B + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + } // diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp index 706b10d8..5f7e73e4 100644 --- a/include/cute/atom/mma_traits_sm80.hpp +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -419,6 +419,203 @@ template <> struct MMA_Traits : MMA_Traits {}; +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + /////////////////////////////////////////////////////////////////////////////// /////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -440,9 +637,13 @@ struct MMA_Traits using CLayout = SM80_16x8_Row; }; +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 & b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + template <> struct MMA_Traits - :MMA_Traits {}; + : MMA_Traits {}; template<> struct MMA_Traits @@ -455,7 +656,7 @@ struct MMA_Traits using Shape_MNK = Shape<_8,_8,_128>; using ThrID = Layout<_32>; using ALayout = Layout,_32>, - Stride,_8>>; + Stride,_8>>; using BLayout = Layout,_32>, Stride,_8>>; using CLayout = SM80_8x8_Row; @@ -472,7 +673,7 @@ struct MMA_Traits using ValTypeA = cute::uint1b_t; using ValTypeB = cute::uint1b_t; using ValTypeC = int32_t; - + using Shape_MNK = Shape<_16,_8,_128>; using ThrID = Layout<_32>; using ALayout = Layout,Shape<_32,_2>>, diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index b02f5b3a..8f59ff55 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -1128,7 +1128,6 @@ struct MMA_Traits> //////////////////////////////////////////////////////////////////////////////////////////////////// - template < GMMA::Major tnspA, GMMA::Major tnspB, diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp index 27c41ad3..161dc7ec 100644 --- a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -7735,4 +7735,4 @@ struct MMA_Traits +# include #endif // _MSC_VER #if defined(__CUDACC_RTC__) diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 57db56ab..48d416f4 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -100,20 +100,30 @@ public: // Copy Ctor CUTE_HOST_DEVICE constexpr - subbyte_reference(subbyte_reference const& other) { - *this = element_type(other); + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); } // Copy Assignment CUTE_HOST_DEVICE constexpr - subbyte_reference& operator=(subbyte_reference const& other) { - return *this = element_type(other); + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); } // Assignment template CUTE_HOST_DEVICE constexpr - enable_if_t, subbyte_reference&> operator=(element_type x) + enable_if_t, subbyte_reference&> operator=(value_type x) { static_assert(is_same_v, "Do not specify template arguments!"); storage_type item = (reinterpret_cast(x) & BitMask); @@ -149,11 +159,11 @@ public: // Value CUTE_HOST_DEVICE - element_type get() const + value_type get() const { if constexpr (is_same_v) { // Extract to bool -- potentially faster impl return bool((*ptr_) & (BitMask << idx_)); - } else { // Extract to element_type + } else { // Extract to value_type // Extract from the current storage element auto item = storage_type((ptr_[0] >> idx_) & BitMask); @@ -165,13 +175,13 @@ public: item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); } - return reinterpret_cast(item); + return reinterpret_cast(item); } } - // Extract to type element_type + // Extract to type value_type CUTE_HOST_DEVICE constexpr - operator element_type() const { + operator value_type() const { return get(); } @@ -341,6 +351,14 @@ recast_ptr(subbyte_iterator const& x) { CUTE_GCC_UNREACHABLE; } +// Dynamic pointers have unknown static alignment +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(subbyte_iterator const& x) { + return {}; +} + template CUTE_HOST_DEVICE void print(subbyte_iterator const& x) { @@ -352,6 +370,7 @@ CUTE_HOST_DEVICE void print(subbyte_reference const& x) { print(x.get()); } + // // array_subbyte // Statically sized array for non-byte-aligned data types diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index bc1b54ef..26195a47 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1830,7 +1830,7 @@ recast_layout(Layout const& layout) return upcast(layout); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 3e5f8362..26ae8dc7 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -616,7 +616,7 @@ recast_layout(ComposedLayout const& layout) return upcast(layout); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; } @@ -631,6 +631,15 @@ max_alignment(ComposedLayout const& layout) return Int<1>{}; } +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + return Layout<_1,_0>{}; +} + // // Display utilities // diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 1b143253..a614bdb2 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -154,13 +154,6 @@ operator*(C, R) { return {}; } -template -CUTE_HOST_DEVICE constexpr -typename R::type -operator/(C, R) { - return {}; -} - // Product with dynamic type needs to produce an integer... template ::value)> @@ -179,6 +172,13 @@ operator*(R, C const& c) { return c * R::num / R::den; } +template +CUTE_HOST_DEVICE constexpr +auto +operator/(C const& c, R) { + return c * R{}; +} + template CUTE_HOST_DEVICE constexpr typename R::type @@ -200,6 +200,10 @@ operator+(C, R) { return {}; } +///////////////// +// Comparisons // +///////////////// + template CUTE_HOST_DEVICE constexpr bool_constant::num == R::num && R::den == R::den> @@ -221,6 +225,31 @@ operator==(C, R) { return {}; } +/////////////////////// +// Special functions // +/////////////////////// + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(C, R) { + return {}; +} + template CUTE_HOST_DEVICE constexpr typename R::type diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp index 07444331..b9943b8c 100644 --- a/include/cute/numeric/numeric_types.hpp +++ b/include/cute/numeric/numeric_types.hpp @@ -46,6 +46,7 @@ template static constexpr auto sizeof_bits_v = sizeof_bits::value; using cutlass::bits_to_bytes; +using cutlass::bytes_to_bits; using cutlass::is_subbyte; diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 4cfa129c..cc49b6a3 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -214,6 +214,14 @@ make_smem_ptr(void const* ptr) { return make_smem_ptr(recast_ptr(ptr)); } +// nullptr_t overload for make_smem_ptr(nullptr) disambiguation +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(decltype(nullptr)) { // nullptr_t + return make_smem_ptr(recast_ptr(nullptr)); +} + // The smem tag is invariant over type-recast template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index 90ca0ceb..57ad0b3c 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -30,9 +30,10 @@ **************************************************************************************************/ #pragma once -#include // CUTE_HOST_DEVICE -#include // cute::sizeof_bits -#include // cute::declval, cute::void_t, etc +#include // CUTE_HOST_DEVICE +#include // cute::sizeof_bits +#include // Int<0> +#include // cute::declval, cute::void_t, etc namespace cute { @@ -115,6 +116,14 @@ raw_pointer_cast(T* ptr) { return ptr; } +// The statically-known alignment of a dynamic pointer is unknown +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(T*) { + return {}; +} + // // A very simplified iterator adaptor. // Derived classed may override methods, but be careful to reproduce interfaces exactly. @@ -169,6 +178,13 @@ raw_pointer_cast(iter_adaptor const& x) { return raw_pointer_cast(x.ptr_); } +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(iter_adaptor const& x) { + return max_alignment(x.ptr_); +} + // // counting iterator -- quick and dirty // diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp index 720b9b12..1a802cfd 100644 --- a/include/cute/pointer_swizzle.hpp +++ b/include/cute/pointer_swizzle.hpp @@ -147,6 +147,14 @@ recast_ptr(swizzle_ptr const& ptr) { return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); } +// The statically-known alignment of a swizzle pointer is the alignment of the swizzle function converted to bits +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(swizzle_ptr const&) { + return Int<8>{} * max_alignment(SwizzleFn{}); +} + // // Display utilities // diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 1324360e..7f7161bc 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -447,7 +447,7 @@ recast_layout(Swizzle const& swizzle) return upcast(swizzle); } else { - static_assert(dependent_false, "Recast not supported."); + return downcast(upcast(layout)); } CUTE_GCC_UNREACHABLE; } @@ -457,7 +457,7 @@ CUTE_HOST_DEVICE constexpr auto max_alignment(Swizzle const&) { - return Int<1 << M>{}; + return Int<(1 << M)>{}; } template diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index 3564c667..2be19c15 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -84,6 +84,8 @@ struct ArrayEngine }; // Specialization for sparse_elem tensor allocation/iteration +// NOTE: This can and should be used for allocation of SMEM as well! +// Fuse these two ArrayEngines? template struct ArrayEngine, N> { @@ -858,6 +860,17 @@ max_common_layout(Tensor const& a, CUTE_GCC_UNREACHABLE; } +/* Return the maximum (statically known) alignment of a Tensor in the number of bits + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Tensor const& t) +{ + return gcd(max_alignment(t.data()), + max_alignment(t.layout()) * static_value>()); +} + // // Key algebraic operations -- Composition, Divide, and Product // diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 86da7cae..26454443 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -123,7 +123,7 @@ bool block([[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) - return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; + return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == static_cast(bid); #else return true; #endif @@ -134,7 +134,7 @@ bool thread([[maybe_unused]] int tid, [[maybe_unused]] int bid) { #if defined(__CUDA_ARCH__) - return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == static_cast(tid)) && block(bid); #else return true; #endif diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index e663b569..a3074ef9 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -141,9 +141,15 @@ using CUTE_STL_NAMESPACE::common_type_t; using CUTE_STL_NAMESPACE::remove_pointer; using CUTE_STL_NAMESPACE::remove_pointer_t; +using CUTE_STL_NAMESPACE::add_pointer; +using CUTE_STL_NAMESPACE::add_pointer_t; + using CUTE_STL_NAMESPACE::alignment_of; using CUTE_STL_NAMESPACE::alignment_of_v; +using CUTE_STL_NAMESPACE::is_pointer; +using CUTE_STL_NAMESPACE::is_pointer_v; + // using CUTE_STL_NAMESPACE::declval; diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index c9689732..460531aa 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -47,6 +47,99 @@ namespace cutlass { namespace arch { //////////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_DEVICE void fence_view_async_shared(); + +namespace detail { // namespace detail begin + +// Single threaded versions that need to be called in an elect_one region +template +CUTLASS_DEVICE +void initialize_barrier_array(T ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array(uint64_t *ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } +} + +// Aligned versions that need to be call warp wide +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(T ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(uint64_t *ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } + } +} + +} // namespace detail end + + // Enumerates the reserved named barriers to avoid potential conflicts // This enum class specifies the NamedBarriers reserved by CUTLASS. enum class ReservedNamedBarriers { diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index b0f75006..0fc60f41 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -35,6 +35,8 @@ #pragma once +#include "cutlass/platform/platform.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// // SM90 @@ -79,3 +81,5 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index 6b487a73..0e957c72 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -35,6 +35,7 @@ #pragma once #include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" #include "cutlass/layout/matrix.h" #include "cute/arch/copy_sm75.hpp" #include "cute/arch/util.hpp" @@ -50,7 +51,7 @@ template < /// .x1, .x2, or .x4 int MatrixCount > -inline __device__ void ldsm(Array & D, void const* ptr); +CUTLASS_DEVICE void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -59,19 +60,19 @@ inline __device__ void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { return cute::cast_smem_ptr_to_uint(ptr); } /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void const *ptr) { return cutlass_get_smem_pointer(const_cast(ptr)); } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -95,7 +96,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -119,7 +120,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -147,7 +148,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -171,7 +172,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { @@ -195,7 +196,7 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index 6471de8a..28bb4638 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -33,11 +33,7 @@ */ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 6cced190..a39ededb 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/arch/wmma.h" diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index f990c1ac..19d78bf2 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "mma.h" diff --git a/include/cutlass/arch/mma_sm89.h b/include/cutlass/arch/mma_sm89.h index fe4b7eb7..d8a75b66 100644 --- a/include/cutlass/arch/mma_sm89.h +++ b/include/cutlass/arch/mma_sm89.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "mma.h" diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index 1183ee5e..16108f0a 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h index 7041d04d..ed2a5ad0 100644 --- a/include/cutlass/arch/mma_sparse_sm80.h +++ b/include/cutlass/arch/mma_sparse_sm80.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm89.h b/include/cutlass/arch/mma_sparse_sm89.h index c092df76..2fae35be 100644 --- a/include/cutlass/arch/mma_sparse_sm89.h +++ b/include/cutlass/arch/mma_sparse_sm89.h @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/simd.h b/include/cutlass/arch/simd.h index 3104746e..f670fc29 100644 --- a/include/cutlass/arch/simd.h +++ b/include/cutlass/arch/simd.h @@ -34,8 +34,8 @@ #pragma once -#include "../array.h" -#include "../numeric_types.h" +#include "cutlass/arch/array.h" +#include "cutlass/arch/numeric_types.h" namespace cutlass { namespace arch { diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp index ea683859..8cf65ad7 100644 --- a/include/cutlass/arch/synclog.hpp +++ b/include/cutlass/arch/synclog.hpp @@ -59,7 +59,7 @@ constexpr uint32_t synclog_cap = 1 << 26; inline std::mutex synclog_mutex; inline std::vector synclog_buf_list; #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -inline __device__ uint32_t* synclog_buf; +CUTLASS_DEVICE uint32_t* synclog_buf; #endif CUTLASS_DEVICE diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 19fda4f8..d75ee2b0 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index 4a268905..b644181b 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index 4663e95c..f6036051 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 62e94694..e85d19fa 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -2573,20 +2573,8 @@ Array fma(Array const &a, Array const &b, T c) { return op(a, b, c); } - //////////////////////////////////////////////////////////////////////////////////////////////////// - - - -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/array_subbyte.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { + //////////////////////////////////////////////////////////////////////////////////////////////////// // AlignedArray @@ -2606,9 +2594,10 @@ public: }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/array_subbyte.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index eb77a931..d2e0e5ef 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -554,6 +554,8 @@ private: //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/blas3.h b/include/cutlass/blas3.h index ee5587d1..d41f1ee6 100644 --- a/include/cutlass/blas3.h +++ b/include/cutlass/blas3.h @@ -132,7 +132,7 @@ struct MantissaInBits { template <> struct MantissaInBits> { static int constexpr bits = 30; - static double constexpr error = 1.0e-15; + static double constexpr error = 1.0e-14; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 78862b0a..0e5d898d 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -189,7 +189,7 @@ private: -problem_shape.dilation[NumSpatialDimensions-1-i] : problem_shape.dilation[NumSpatialDimensions-1-i]; } - + return make_im2col_tma_copy( GmemTiledCopyA{}, tensor_a, @@ -225,7 +225,7 @@ private: auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); auto lower_srt = detail::compute_lower_srt(problem_shape); - + return make_im2col_tma_copy( GmemTiledCopyB{}, tensor_b, @@ -372,6 +372,96 @@ public: return false; } + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::ostringstream os; +#endif + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (input_stride[i] != input_shape_size) { + os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n input_shape_size: " << input_shape_size + << "\n input_shape: " << input_shape + << "\n input_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (output_stride[i] != output_shape_size) { + os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n output_shape_size: " << input_shape_size + << "\n output_shape: " << input_shape + << "\n output_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + if (problem_shape.groups > 1) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); return false; @@ -516,9 +606,9 @@ public: // Issue the epilogue waits if (lane_predicate) { /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all + * Waits for all stages to either be released (all * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was + * then would just be acquired since the phase was * still inverted from make_producer_start_state */ pipeline.producer_tail(smem_pipe_producer_state); @@ -645,7 +735,7 @@ public: k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); - + // Wait on all GMMAs to complete warpgroup_wait<0>(); diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp index ffcc547f..cd2f674f 100644 --- a/include/cutlass/conv/convnd_problem_shape.hpp +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -319,6 +319,7 @@ struct ConvProblemShape { // | ShapeB | KTRSC | KTRSC | NDHWC | // | ShapeC | NZPQK | NDHWC | KTRSC | // + // Input comes from calculate_xformed_act, which does NOT depend on ConvOp. CUTLASS_HOST_DEVICE constexpr void set_shape_stride_ABC( @@ -328,6 +329,31 @@ struct ConvProblemShape { TensorStride stride_flt, TensorExtent shape_xformed_act, TensorStride stride_xformed_act) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("*** set_shape_stride_ABC ***"); + printf("\n shape_act: "); + print(shape_act); + printf("\n stride_act: "); + print(stride_act); + printf("\n shape_flt: "); + print(shape_flt); + printf("\n stride_flt: "); + print(stride_flt); + printf("\n shape_xformed_act: "); + print(shape_xformed_act); + printf("\n stride_xformed_act: "); + print(stride_xformed_act); + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + printf("\n ConvOp: Fprop"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + printf("\n ConvOp: Dgrad"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + printf("\n ConvOp: Wgrad"); + } + printf("\n"); +#endif if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { shape_A = shape_act; @@ -353,6 +379,20 @@ struct ConvProblemShape { shape_C = shape_flt; stride_C = stride_flt; } +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n shape_A: "); + print(shape_A); + printf("\n stride_A: "); + print(stride_A); + printf("\n shape_B: "); + print(shape_B); + printf("\n stride_B: "); + print(stride_B); + printf("\n shape_C: "); + print(shape_C); + printf("\n stride_C: "); + print(stride_C); +#endif } // Get A extents. diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h index 5e429956..d4e98fa4 100644 --- a/include/cutlass/conv/kernel/direct_convolution.h +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -40,6 +40,7 @@ #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" #include "cutlass/semaphore.h" #include "cutlass/tensor_ref.h" #include "cutlass/layout/tensor.h" @@ -155,7 +156,7 @@ struct DirectConvolutionParams { swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); // Dynamic SMEM usage because stride and dilation are runtime params. - smem_size_ = (max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); + smem_size_ = (cutlass::platform::max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index d778046c..fe884d70 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -37,7 +37,7 @@ #if defined(__CUDACC_RTC__) #include #else -#include +#include #endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 1c8f56a6..2adfd266 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -85,7 +85,11 @@ namespace cutlass { #if !defined(__CUDACC_RTC__) +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) #include +#endif // (__CUDACC_VERSION__ >= 11.8) + #include #define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok @@ -100,7 +104,8 @@ namespace cutlass { #else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) -#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) +#if ((__CUDACC_VER_MAJOR__ >= 13) || \ + ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) \ #define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ template \ @@ -138,7 +143,7 @@ namespace cutlass { return reinterpret_cast(pfn)(args...); \ } -#endif // (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) +#endif // (__CUDACC_VERSION__ >= 12.5) #endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index a4b288e7..9d8f9e2f 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -31,6 +31,7 @@ #pragma once #include "cute/container/tuple.hpp" +#include "cute/layout.hpp" // cute::size(shape) ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index f9f348b9..c740eb98 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -237,7 +237,7 @@ struct LayoutAwareConvertImpl< } }; -// Specialization for UINT4 -> FPF16 with [02461357] value order +// Specialization for UINT4 -> FP16 with [02461357] value order template <> struct LayoutAwareConvertImpl< cutlass::uint4b_t, @@ -754,7 +754,6 @@ public: cute::tuple& partitioned_extra_info, int const k_block) { - static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); static_assert(cosize_v == cosize_v); @@ -805,14 +804,15 @@ public: { auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; asm volatile( "{\n" - " and .b32 %0, %2, %4 ;\n" \ - " and .b32 %1, %3, %5 ;\n" \ + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ "}\n" : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) - : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0x7F7F7F00), "n"(0x7F7F7F7F) - ); + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); } } CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp index 4cd895f1..039f5e84 100644 --- a/include/cutlass/detail/helper_macros.hpp +++ b/include/cutlass/detail/helper_macros.hpp @@ -57,6 +57,12 @@ #define CUTLASS_DEVICE inline #endif +#if ! defined(_MSC_VER) +#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline)) +#else +#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]] +#endif + #define CUTLASS_HOST __host__ #define CUTLASS_GLOBAL __global__ static @@ -74,11 +80,11 @@ CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) #ifdef _MSC_VER // Provides support for alternative operators 'and', 'or', and 'not' -#include +#include #endif // _MSC_VER #if !defined(__CUDACC_RTC__) -#include +#include #endif #if defined(__CUDA_ARCH__) diff --git a/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp new file mode 100644 index 00000000..914443dd --- /dev/null +++ b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Mainloop Fusion configs specific for scale factors +*/ + +#pragma once + +#include // cute::void_t + +namespace cutlass::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ElementSFType { + using type = void; +}; + +template +struct ElementSFType> { + using type = typename CollectiveMainloop::ElementSF; +}; + +template +struct LayoutSFAType { + using type = void; +}; + +template +struct LayoutSFAType> { + using type = typename CollectiveMainloop::LayoutSFA; +}; + +template +struct LayoutSFBType { + using type = void; +}; + +template +struct LayoutSFBType> { + using type = typename CollectiveMainloop::LayoutSFB; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 7af5d96c..cc7caede 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -34,8 +34,11 @@ #pragma once +#include // CUTLASS_HOST_DEVICE +#include // uint64_t + // __grid_constant__ was introduced in CUDA 11.7. -#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) && !CUTLASS_CLANG_CUDA # define CUTLASS_GRID_CONSTANT_SUPPORTED #endif diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 759591b5..720dcc00 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -422,7 +422,8 @@ struct CollectiveBuilder< Schedule, fusion::LinearCombination, cute::enable_if_t || - cute::is_same_v >> { + cute::is_same_v || + cute::is_same_v >> { // Passing void C disables source load using ElementC = cute::conditional_t, diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp index 0f6f3293..da7562b4 100644 --- a/include/cutlass/epilogue/collective/default_epilogue_array.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -86,7 +86,7 @@ public: static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; - static_assert(cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -198,20 +198,30 @@ public: assert(0); } - InternalStrideC stride_c; - InternalStrideD stride_d; - if constexpr (!cute::is_same_v) { - // If grouped gemm - if (epilogue_op.is_source_needed()) { - stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); } - stride_d = detail::get_epilogue_stride(params.dD[l_coord]); - } - else { - stride_c = detail::get_epilogue_stride(params.dC); - stride_d = detail::get_epilogue_stride(params.dD); - } - + }(); + // Represent the full output tensor ElementC const* ptr_C_l = nullptr; if (epilogue_op.is_source_needed()) { diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 6c0368e0..23e57d99 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -157,7 +157,8 @@ struct EmptyStorage { template CUTLASS_HOST_DEVICE auto get_epilogue_stride(Stride stride){ - if constexpr (cute::is_base_of_v) { + if constexpr (cute::is_base_of_v|| + cute::is_base_of_v) { return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); } else { @@ -464,7 +465,7 @@ public: tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } }; -// SFINAE helpers for detecting beta/beta_ptr in EVT arguments. +// SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. template struct has_beta { static constexpr bool value = false; @@ -485,6 +486,16 @@ struct has_beta_ptr +struct has_beta_ptr_array { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr_array> { + static constexpr bool value = true; +}; + } // namespace detail } // namespace collective } // namespace epilogue diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index 84b6e14e..54fe9b1d 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -328,7 +328,7 @@ public: } uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c = {}; + typename Params::TMA_C tma_load_c{}; if constexpr (is_source_supported) { ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); @@ -409,7 +409,7 @@ public: implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); } - if constexpr (not cute::is_void_v) { + if constexpr (is_source_supported) { constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); @@ -432,13 +432,16 @@ public: bool beta_implementable = true; - if constexpr (cute::is_void_v) { + if (cute::is_void_v || args.ptr_C == nullptr) { if constexpr (detail::has_beta::value) { beta_implementable = args.thread.beta == 0.0; } if constexpr (detail::has_beta_ptr::value) { beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } } if (!beta_implementable) { @@ -775,7 +778,7 @@ public: tRS_rC, thread_idx }; - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); @@ -1017,7 +1020,7 @@ public: Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) if constexpr (IsLoad) { - if (not cute::is_void_v) { + if (is_source_supported) { constexpr int C_tensormap_index = NumEpilogueWarpGroups; Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); @@ -1058,8 +1061,10 @@ public: // Replacing global_address for the next batch if constexpr (IsLoad) { if constexpr (is_source_supported) { - cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, - params.ptr_C[next_batch]); + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, + params.ptr_C[next_batch]); + } } } else if constexpr (is_destination_supported) { @@ -1087,18 +1092,20 @@ public: if constexpr (IsLoad) { if constexpr (is_source_supported) { - ElementC const* ptr_C = nullptr; - Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); - cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, - prob_shape, prob_stride); - // Convert strides to byte strides - for (uint64_t& stride : prob_stride) { - stride = (stride * sizeof_bits_v) / 8; + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); } - cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, - prob_shape, - prob_stride); } } else if constexpr (is_destination_supported) { @@ -1166,7 +1173,7 @@ public: void tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { if constexpr (IsLoad) { - if constexpr (not cute::is_void_v) { + if constexpr (is_source_supported) { cute::tma_descriptor_fence_acquire(tensormap); } } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index b96c4aea..b3c7bf38 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -94,7 +94,7 @@ class CollectiveEpilogue< SmemLayoutAtomD_, CopyOpR2S_, CopyAtomC_, - CopyOpR2R_, + CopyOpR2R_ > { public: // @@ -136,6 +136,9 @@ private: static_assert(not cute::is_void_v, "SmemElementD is void"); using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; @@ -239,14 +242,14 @@ public: struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideC{}, int32_t(0)), StrideC{}), take<0,2>(SmemLayoutC{}), EpilogueTile{}, _1{})); using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideD{}, int32_t(0)), StrideD{}), take<0,2>(SmemLayoutD{}), EpilogueTile{}, @@ -273,9 +276,9 @@ public: auto [M, N, K, L] = problem_shape_MNKL; uint32_t transaction_bytes = TmaTransactionBytes; - typename Params::TMA_C tma_load_c = {}; + typename Params::TMA_C tma_load_c{}; if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); tma_load_c = make_tma_copy_C_sm90( CopyOpG2S{}, tensor_c, @@ -285,7 +288,7 @@ public: typename Params::TMA_D tma_store_d; if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); tma_store_d = make_tma_copy_C_sm90( CopyOpS2G{}, tensor_d, @@ -644,7 +647,18 @@ public: // Absolute coordinate tensors (dynamic) Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) - Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + Tensor tRS_cD_mn = [&]() { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. + TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); // Relative coordinate tensors (static) Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index f829a2ff..a5f47f08 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -50,6 +50,7 @@ struct EpilogueSimtVectorized {}; struct EpiloguePtrArraySimtVectorized {}; struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecializedTransposed {}; struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 3aed3271..1ef06a53 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -34,6 +34,7 @@ #include #include #include +#include // cute::false_type ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,9 +61,12 @@ struct FusionOperation { static constexpr int AlignmentScalar = 0; static constexpr bool IsScaleFactorSupported = false; static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + using ElementBias = void; static constexpr int AlignmentBias = 0; static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; static constexpr bool IsDePerRowBiasSupported = false; using ActivationFn = void; @@ -190,6 +194,24 @@ struct LinCombPerRowBiasEltAct static constexpr bool IsEltActSupported = true; }; +// D = activation(alpha * acc + beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltAct + : LinCombPerColBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + // D = activation(alpha * acc + beta * C + per-row bias) // aux = alpha * acc + beta * C + per-row bias template< @@ -214,6 +236,30 @@ struct LinCombPerRowBiasEltActAux static constexpr bool IsAuxOutSupported = true; }; +// D = activation(alpha * acc + beta * C + per-col bias) +// aux = alpha * acc + beta * C + per-col bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActAux + : LinCombPerColBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + // D = activation(per-row alpha * acc + per-row beta * C + per-row bias) template< template class ActivationFn_, @@ -233,6 +279,46 @@ struct PerRowLinCombPerRowBiasEltAct static constexpr bool IsPerRowScaleSupported = true; }; +// D = per-column alpha * per-row alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementCompute_, + class ElementScalar_ = ElementCompute_, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct OuterProdLinComb : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr auto RoundStyle = RoundStyle_; + static constexpr bool IsSourceSupported = true; + static constexpr bool IsPerRowScaleSupported = true; + static constexpr bool IsPerColScaleSupported = true; +}; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerColScaleSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias // if D is fp8 // D = scale_d * activation(Z) @@ -254,6 +340,27 @@ struct ScaledLinCombPerRowBiasEltAct static constexpr bool IsScaleFactorSupported = true; }; +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias // if D is fp8 // amax_d = max(abs(elements in activation(Z))) @@ -291,6 +398,43 @@ struct ScaledLinCombPerRowBiasEltActAmaxAux static constexpr bool IsAuxOutSupported = true; }; +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltActAmaxAux + : ScaledLinCombPerColBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + // Z = Aux // dY = alpha * acc + beta * C // D = d_activation(dY, Z) diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index e028846a..3e57fa0b 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -708,6 +708,105 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltAct = + Sm90EVT, + Sm90LinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = activation(alpha * acc + beta * C + per-row bias) // Aux = alpha * acc + beta * C + per-row bias) template< @@ -832,6 +931,132 @@ struct FusionCallbacks< }; ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per_col bias) +// Aux = alpha * acc + beta * C + per_col bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerColBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = per-row alpha * acc + per-row beta * C + per-row bias template< class CtaTileShapeMNK, @@ -954,6 +1179,133 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = per-col alpha * acc + per-col beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBiasEltAct = + Sm90EVT, + Sm90PerColLinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + namespace detail { template @@ -1120,6 +1472,154 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerColBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias // if D is fp8 // amax_d = max(abs(elements in activation(Z))) @@ -1440,6 +1940,326 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias + Sm90ScaledLinCombPerColBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerColBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + template< class CtaTileShapeMNK, class EpilogueTile, @@ -1679,6 +2499,87 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +// D = per-column alpha * per-row alpha * acc + beta * c +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentScalar = 128 / sizeof_bits_v, // Alignment of per-column and per-row scaling vectors + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90OuterProdLinComb = + Sm90EVT, // c(beta) * c(C) + c(alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // c(alpha) * c(acc) + Sm90OuterProduct<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, Stride<_0,_1,int>, AlignmentScalar>, // alpha_col * alpha_row + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + OuterProdLinComb, + CtaTileShapeMNK, + EpilogueTile +> : Sm90OuterProdLinComb { + using Impl = Sm90OuterProdLinComb; + using Operation = OuterProdLinComb; + + struct Arguments { + + // Give a name and flat ordering to the fusion callback args + using StrideCol = Stride<_1,_0,int>; + using StrideRow = Stride<_0,_1,int>; + using StrideBeta = Stride<_0,_0,int>; + ElementScalar const* alpha_ptr_col = nullptr; + ElementScalar const* alpha_ptr_row = nullptr; + ElementScalar beta = static_cast(0); + ElementScalar const* beta_ptr = nullptr; + StrideCol dAlphaCol = {}; + StrideRow dAlphaRow = {}; + StrideBeta dBeta = {}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { + {beta, beta_ptr, dBeta}, // leaf args : beta + {}, // leaf args : C + { + { alpha_ptr_col, alpha_ptr_row, dAlphaCol, dAlphaRow }, // leaf args : alpha cols / rows + {}, // leaf args : acc + {} + }, + {} + }; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // D = softmax(top_k(alpha * acc + beta * C)) template< int TopK, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 131d0ba5..321daa6b 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -266,8 +266,8 @@ struct Sm90TreeVisitor< auto const& scale_op = get<0>(Impl::ops); auto const& added_op = get<2>(Impl::ops); if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { - return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || - is_C_load_needed() || + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || added_op.is_producer_load_needed(); } else { @@ -408,8 +408,9 @@ template < > struct Sm90TreeVisitor< Sm90Compute, cutlass::epilogue::thread::ReLu> || - cute::is_same_v, cutlass::epilogue::thread::Clamp> >>, + cute::enable_if_t, cutlass::epilogue::thread::ReLu> || + cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, Sm90TreeVisitor< Sm90AuxStore< Stages, @@ -503,7 +504,8 @@ struct Sm90TreeVisitor< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { ElementCompute pre_relu = frg_compute[i]; - if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp>) { + if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { frg_compute[i] = relu(frg_compute[i], params_compute); } else { diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index a22bed4e..66b1086e 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -734,11 +734,12 @@ private: // Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors template< class Element, - class StrideMNL = Stride<_0,_0,_0>, + class StrideMNL_ = Stride<_0,_0,_0>, int BroadcastCount = 1, template class ReductionFn = multiplies > struct Sm90ScalarBroadcastPtrArray { + using StrideMNL = StrideMNL_; static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); @@ -780,8 +781,8 @@ struct Sm90ScalarBroadcastPtrArray { CUTLASS_DEVICE bool is_producer_load_needed() const { - // producer load is needed if Element is not void and we have multiple scalars - return !cute::is_void_v and size<2>(params_ptr->dScalar[0]) != 0; + // producer load is needed if Element is not void + return !cute::is_void_v; } CUTLASS_DEVICE bool @@ -814,7 +815,7 @@ struct Sm90ScalarBroadcastPtrArray { CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { + if (size<2>(params_ptr->dScalar[0]) != 0) { auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -1377,6 +1378,171 @@ struct Sm90ColBroadcast { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Do outer product from the column and row loaded +// +template< + int Stages, + class CtaTileShapeMNK, + class ElementScalar, + class StrideColMNL_ = Stride<_1,_0,int64_t>, /// NOTE: Batched scaling untested for now + class StrideRowMNL_ = Stride<_0,_1,int64_t>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = false // Fallback scalar broadcast for nullptr params +> +struct Sm90OuterProduct { + using StrideColMNL = StrideColMNL_; + using StrideRowMNL = StrideRowMNL_; + static_assert(Stages == 0, "OuterProduct doesn't support smem usage"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(!EnableNullptr, "Nullptr fallback not implemented"); + static_assert(is_static_v(StrideColMNL{}))> && + is_static_v(StrideRowMNL{}))>, "Only batch stride can be dynamic"); + static_assert(take<0,2>(StrideColMNL{}) == Stride<_1,_0>{} && + take<0,2>(StrideRowMNL{}) == Stride<_0,_1>{}, "Row and column incorrectly formatted"); + + // Accumulator distributes col/row elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + ElementScalar const* ptr_col = nullptr; + ElementScalar const* ptr_row = nullptr; + StrideColMNL dCol = {}; + StrideRowMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct() { } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorCol, class RTensorCol, + class GTensorRow, class RTensorRow + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorCol&& tCgCol, RTensorCol&& tCrCol, + GTensorRow&& tCgRow, RTensorRow&& tCrRow, + Params const& params) + : tCgCol(cute::forward(tCgCol)) + , tCrCol(cute::forward(tCrCol)) + , tCgRow(cute::forward(tCgRow)) + , tCrRow(cute::forward(tCrRow)) + , params(params) {} + + GTensorCol tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorCol tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensorRow tCgRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorRow tCrRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + + // Filter so we don't issue redundant copies over stride-0 modes + copy(filter(tCgCol), filter(tCrCol)); + copy(filter(tCgRow), filter(tCrRow)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_colrow; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_colrow[i] = static_cast(tCrCol_mn(epi_v * FragmentSize + i) * tCrRow_mn(epi_v * FragmentSize + i)); + } + return frg_colrow; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mRow, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks< + decltype(tCgCol), decltype(tCrCol), + decltype(tCgRow), decltype(tCrRow) + >( + cute::move(tCgCol), cute::move(tCrCol), + cute::move(tCgRow), cute::move(tCrRow), + params + ); + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // Batch matrix broadcast diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index f9ebe739..83cfc030 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -293,11 +293,11 @@ template < class LayoutOrStrideMNL, class SmemLayoutAtom, // Unused class CopyOpR2S, // Unused - int Alignment, + int Alignment, bool EnableNullptr > struct Sm90AuxStore< - 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr > { using ElementAux = Element; @@ -343,7 +343,7 @@ struct Sm90AuxStore< CUTLASS_HOST_DEVICE Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; CUTLASS_DEVICE bool @@ -381,7 +381,7 @@ struct Sm90AuxStore< tC_cAux(cute::forward(tC_cAux)), problem_shape_mnl(problem_shape_mnl), params_ptr(params_ptr) {} - + GTensorR2G tC_gAux; RTensor tC_rAux; CTensorR2G tC_cAux; @@ -414,7 +414,7 @@ struct Sm90AuxStore< Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); - + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); @@ -451,7 +451,7 @@ struct Sm90AuxStore< // Predication support Tensor coordAux = make_identity_tensor(shape(mAux)); Tensor tC_cAux = sm90_partition_for_epilogue( - coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( cute::move(tC_gAux), @@ -703,7 +703,6 @@ public: else if constexpr (FinalReduction) { auto problem_shape_mnkl = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -753,19 +752,18 @@ public: static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { -#if !defined(CUTLASS_SKIP_REDUCTION_INIT) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); if (args.ptr_row != nullptr) { return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); } return Status::kSuccess; } - else -#endif - if constexpr (FinalReduction) { + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -939,7 +937,7 @@ public: for (int v = 0; v < size(frg_A); ++v) { // Step1: swap if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second - swap(frg_A(v), frg_B(v)); + cutlass::swap(frg_A(v), frg_B(v)); } // Step2: shuffle @@ -1023,9 +1021,7 @@ public: } else { if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), recast(filter(tCgBuf))); + copy_aligned(tCrRow, recast(tCgBuf)); } } sync_fn(); @@ -1054,9 +1050,7 @@ public: } else { if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrRow), filter(tCsBuf)); + copy_aligned(tCrRow, tCsBuf); } } sync_fn(); @@ -1296,7 +1290,6 @@ public: else if constexpr (FinalReduction) { auto problem_shape_mnkl = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_mnkl; - auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1348,19 +1341,18 @@ public: static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { -#if !defined(CUTLASS_SKIP_REDUCTION_INIT) - auto problem_shape_mnkl = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_mnkl; if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); if (args.ptr_col != nullptr) { return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); } return Status::kSuccess; } - else -#endif - if constexpr (FinalReduction) { + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); @@ -1522,9 +1514,7 @@ public: using ElementGmem = cute::conditional_t; Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); if (is_reduced_lane) { - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrCol), recast(filter(tCgBuf))); + copy_aligned(tCrCol, recast(tCgBuf)); } sync_fn(); } @@ -1542,9 +1532,7 @@ public: // Dump warp reduction to smem workspace Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); if (is_reduced_lane) { - // Filter so we don't issue redunant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_aligned(filter(tCrCol), filter(tCsBuf)); + copy_aligned(tCrCol, tCsBuf); } sync_fn(); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 4f7d99fa..48f4756d 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -300,7 +300,6 @@ struct Sm90VisitorImplBase { tuple ops; }; - template struct Sm90VisitorImpl : Sm90VisitorImplBase { @@ -658,7 +657,6 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl(std::move(callbacks_tuple)); } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// template< diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 9f1cd774..186e9966 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -258,6 +258,54 @@ struct LeakyReLU > { } }; +// Y = min((X <= threshold ? 0 : X), upper_bound) +template +struct ThresholdReLU { + static constexpr bool kIsHeavy = false; + + struct Arguments { + T threshold = T(0); + T upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + }; + + CUTLASS_HOST_DEVICE + T operator()(T value, T threshold, T upper_bound) const { + minimum_with_nan_propagation mn; + + return mn((value <= threshold ? T(0) : value), upper_bound); + } + + CUTLASS_HOST_DEVICE + T operator()(T value, Arguments const& args = Arguments()) const { + return operator()(value, args.threshold, args.upper_bound); + } +}; + +template +struct ThresholdReLU> { + static constexpr bool kIsHeavy = false; + + using Arguments = typename ThresholdReLU::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, T threshold, T upper_bound) const { + ThresholdReLU relu; + + Array retvals; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + retvals[i] = relu(values[i], threshold, upper_bound); + } + + return retvals; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const& values, Arguments const& args = Arguments()) const { + return operator()(values, args.threshold, args.upper_bound); + } +}; + // Tanh operator template struct Tanh { @@ -311,26 +359,7 @@ struct Sigmoid { }; template -struct Sigmoid > { - static const bool kIsHeavy = true; - - CUTLASS_HOST_DEVICE - Array operator()(Array const &value) const { - Array y; - Sigmoid sigmoid_op; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N; ++i) { - y[i] = sigmoid_op(value[i]); - } - - return y; - } -}; - -template -struct Sigmoid> { - using T = half_t; +struct Sigmoid> { static const bool kIsHeavy = true; CUTLASS_HOST_DEVICE @@ -450,6 +479,9 @@ struct HardSwish > { } }; +template +using ScaledHardSwish = Scale>; + // // GELU function definitions implemented as described by // Hendrycks, D., and Gimpel, K. in diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index f74a36af..c3aa3ff4 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -169,7 +169,7 @@ public: /// Constructs the function object, possibly loading from pointers in host memory CUTLASS_HOST_DEVICE - LinearCombination(Params const ¶ms, int group_idx = 0) { + explicit LinearCombination(Params const ¶ms, int group_idx) { if (params.alpha_ptr_array != nullptr && params.alpha_ptr_array[group_idx] != nullptr) { alpha_ = *(params.alpha_ptr_array[group_idx]); } @@ -190,6 +190,10 @@ public: } } + CUTLASS_HOST_DEVICE + explicit LinearCombination(const Params & params) + : LinearCombination(params, /* group_idx */ 0) { } + /// Returns true if source is needed CUTLASS_HOST_DEVICE bool is_source_needed() const { diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 48b66a14..4a0c67ba 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -478,6 +474,12 @@ public: // Iterate over accumulator tile // + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clangs warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { @@ -531,6 +533,10 @@ public: destination_iterator.store(output_fragment); ++destination_iterator; } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif } }; diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 6853f5f0..30432e80 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -43,11 +43,7 @@ #include #endif -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h index 43b14c35..486c0304 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h index 2be1fa55..85ddae7c 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +++ b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h @@ -38,11 +38,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h index 9efbee47..aff05485 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h @@ -39,11 +39,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h index 9bae7a74..df5bbc5c 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h @@ -50,11 +50,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 7e6d2a69..d69f43c4 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -39,11 +39,11 @@ #pragma once -#if defined(__CUDACC_RTC__) #include + +#if defined(__CUDACC_RTC__) #include #else -#include #include #endif diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index 1d4c7016..7f82bac7 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -39,11 +39,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h index 259f0729..027830c2 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h @@ -303,6 +303,12 @@ public: // Pipeline Loop // + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clang warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? kIterations : 1) for (int iter_idx = 1; iter_idx < kIterations + 1; ++iter_idx) { @@ -377,8 +383,19 @@ public: callbacks.end_step(iter_idx-1); } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif + } else { + #ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wcuda-compat" + // Turn off clang warning about loop unroll argument using parens. + #endif + #pragma unroll(IterationsUnroll ? kIterations : 1) for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { @@ -459,6 +476,11 @@ public: callbacks.end_step(iter_idx); } + + #ifdef __clang__ + #pragma clang diagnostic pop + #endif + } callbacks.end_epilogue(); diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index 7a332f11..28d482b7 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -335,7 +335,8 @@ struct VisitorAuxLoad{ template< class ThreadMap, class Element, - class StrideMNL + class StrideMNL, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct VisitorRowBroadcast { @@ -399,6 +400,16 @@ struct VisitorRowBroadcast { CUTLASS_DEVICE void begin_epilogue() { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_row == nullptr) { + auto tC_rRow_vec = recast>(coalesce(tC_rRow)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tC_rRow_vec); ++i) { + tC_rRow_vec[i].fill(params_ptr->null_default); + } + return; + } + } clear(tC_rRow); auto src_v = filter(tC_gRow); auto coord_v = filter(tC_cRow); @@ -406,7 +417,7 @@ struct VisitorRowBroadcast { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(src_v); ++i) { bool guard = get<1>(coord_v(i)) < n; - cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + cutlass::arch::global_load(dst_v(i), (void const *)&src_v(i), guard); } } @@ -464,7 +475,8 @@ struct VisitorRowBroadcast { template< class ThreadMap, class Element, - class StrideMNL = Stride<_1,_0,_0> + class StrideMNL = Stride<_1,_0,_0>, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params > struct VisitorColBroadcast { @@ -523,6 +535,12 @@ struct VisitorColBroadcast { CUTLASS_DEVICE void begin_epilogue() { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_col == nullptr) { + fill(tC_rCol, params_ptr->null_default); + return; + } + } clear(tC_rCol); Tensor pred = make_tensor(shape(tC_gCol)); CUTLASS_PRAGMA_UNROLL diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp index 1c24e22d..dcec7ac8 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp @@ -519,10 +519,7 @@ struct VisitorRowReduction { // Guard against uses of the existing SMEM tile __syncthreads(); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rSrc); ++i) { - copy_vec(filter(tRS_rSrc), filter(tRS_sRows)); - } + copy(tRS_rSrc, tRS_sRows); __syncthreads(); diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 617b8e39..8a88c0ab 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -391,7 +391,7 @@ struct OutputTileOptimalThreadMap { 1>; /// Initial offset function - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static MatrixCoord initial_offset(int thread_idx) { // int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); @@ -462,7 +462,7 @@ struct OutputTileOptimalThreadMap { static int const kThreads = Threads; /// Function to compute each thread's initial offset - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static MatrixCoord initial_offset(int thread_idx) { // int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h index c512dd87..3322a4c6 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h @@ -212,15 +212,23 @@ public: // When the optimization is enabled, small tiles require separate logic. bool kN32_optimization = (WarpShape::kN * Detail::kLanesInQuad * Policy::kElementsPerAccess * sizeof_bits::value) % 1024 == 0; if (kN32_optimization) { + int ptr_idx = ((warp_column_ * sizeof_bits::value) / 1024) % Detail::kPointerCount; + if (ptr_idx == 0) { ptr = pointers_[0]; } else if (ptr_idx == 1) { - ptr = pointers_[1]; + if constexpr (AccessType::kElements >= 2) { + ptr = pointers_[1]; + } } else if (ptr_idx == 2) { - ptr = pointers_[2]; + if constexpr (AccessType::kElements >= 3) { + ptr = pointers_[2]; + } } else if (ptr_idx == 3) { - ptr = pointers_[3]; + if constexpr (AccessType::kElements >= 4) { + ptr = pointers_[3]; + } } } diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index fa3873c5..4ca8e113 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -38,7 +38,7 @@ #include #include #endif - +#include #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" @@ -54,12 +54,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTLASS_HOST_DEVICE void swap(T &lhs, T &rhs) { - T tmp = lhs; - lhs = rhs; - rhs = tmp; -} +using ::cuda::std::swap; /****************************************************************************** * Static math utilities diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 38ea4008..cfb6b8bb 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -1053,8 +1053,8 @@ float_e5m2_t::float_e5m2_t(float_e4m3_t x) { /// datatype in runtime argument list. /// /// Currently supported runtime datatypes compatible with type_erased_dynamic_float8_t: -/// QMMAFormat::E5M2 -/// QMMAFormat::E4M3 +/// MXF8F6F4Format::E5M2 +/// MXF8F6F4Format::E4M3 /// /////////////////////////////////////////////////////////////// diff --git a/include/cutlass/floating_point_nvrtc.h b/include/cutlass/floating_point_nvrtc.h index fdbd80fc..c08396aa 100644 --- a/include/cutlass/floating_point_nvrtc.h +++ b/include/cutlass/floating_point_nvrtc.h @@ -35,6 +35,12 @@ #pragma once +#include // CUTLASS_HOST_DEVICE +#include // uint32_t +#if !defined(__CUDACC_RTC__) +#include // std::memcpy +#endif + namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 5b2bc3c6..3c4d5c76 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -50,7 +50,7 @@ #ifdef _MSC_VER // Provides support for alternate operators such as 'and', 'or', ... -#include +#include #endif // _MSC_VER namespace cutlass { diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 64e27a8d..f58fde88 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -35,6 +35,8 @@ #include "cutlass/pipeline/sm90_pipeline.hpp" #include "cutlass/gemm/collective/collective_mma_decl.hpp" #include "cutlass/gemm/collective/collective_builder_decl.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" // SM90 Collective Builders should be used only starting CUDA 12.0 #if (__CUDACC_VER_MAJOR__ >= 12) @@ -236,8 +238,9 @@ struct CollectiveBuilder< GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes; + + static constexpr int Sm90ReducedSmemCapacityBytes = + detail::sm90_smem_capacity_bytes; static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); @@ -343,7 +346,7 @@ public: return t; } else { - return cute::stride(t); + return cute::stride(t); } } @@ -415,15 +418,15 @@ public: static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; - static constexpr int PipelineStages = IsMixedInput ? - detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) : - detail::compute_stage_count_or_override(StageCountType{}); + static constexpr int PipelineStages = IsMixedInput ? + detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) + : detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaRmemAWarpSpecialized>; + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput + , MainloopSm90TmaGmmaRmemAWarpSpecialized>; using SmemCopyAtomA = cute::conditional_t>; using SmemCopyAtomB = cute::conditional_t, void>; @@ -761,13 +764,13 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = cute::is_same_v ? 2 : 1; using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; - using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; - using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); @@ -867,13 +870,13 @@ struct CollectiveBuilder< static constexpr int NumLoadWarpGroups = 1; using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; - using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomA, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; - using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumThreadsPerWarpGroup * NumLoadWarpGroups, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); diff --git a/include/cutlass/gemm/collective/collective_builder_decl.hpp b/include/cutlass/gemm/collective/collective_builder_decl.hpp index c0570d37..c27a84f2 100644 --- a/include/cutlass/gemm/collective/collective_builder_decl.hpp +++ b/include/cutlass/gemm/collective/collective_builder_decl.hpp @@ -54,6 +54,18 @@ struct StageCountAutoCarveout { explicit StageCountAutoCarveout(cute::Int) {} }; +namespace detail { + +// Forward Declaration +template +constexpr int +compute_carveout_from_epi(); + +} // namespace detail + +template +struct StageCountAutoCarveoutEpi : StageCountAutoCarveout()> {}; + using StageCountAuto = StageCountAutoCarveout<0>; // Used to automatically let the builder pick the kernel schedule. diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 103da9af..21f8a557 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -41,9 +41,10 @@ #include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 00000000..ed223a56 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1370 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule_, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ +public: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput; + using TileShape = TileShape_; + using KernelSchedule = KernelSchedule_; + +private: + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + // + // Type Aliases + // + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + using StrideScale = cute::Stride, int64_t, int64_t>; + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert(( IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)) || + (!IsATransformed && (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + ((cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value) && + (cutlass::gemm::detail::is_k_major() || is_layout::value || is_layout::value)), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + using SwappedSmemLayoutAtomA = cute::conditional_t; + using SwappedSmemLayoutAtomB = cute::conditional_t; + using SwappedSmemCopyAtomA = cute::conditional_t; + using SwappedSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using SwappedTransformA = cute::conditional_t; + using SwappedTransformB = cute::conditional_t; + using ArchTag = typename DispatchPolicy::ArchTag; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SwappedSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + /// Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomA{}, select<0,2>(TileShape{}), InternalSwappedStrideA{})); + using SmemLayoutB = decltype(detail::get_smem_layout(SwappedSmemLayoutAtomB{}, select<1,2>(TileShape{}), InternalSwappedStrideB{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage { + CUTE_ALIGNAS(SmemAlignmentA) cute::ArrayEngine> smem_A; + CUTE_ALIGNAS(SmemAlignmentB) cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + struct TensorMapStorage { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_scale; + cute::TmaDescriptor smem_tensormap_zero; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementScale const** ptr_S = nullptr; + NonVoidStrideScale const* dS{}; + int chunk_size = 0; + ElementZero const** ptr_Z = nullptr; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using LayoutA = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideA{}, int32_t(0)), InternalSwappedStrideA{})); + using LayoutB = decltype(detail::get_gmem_layout(repeat_like(InternalSwappedStrideB{}, int32_t(0)), InternalSwappedStrideB{})); + + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), LayoutB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(detail::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + void* tensormaps; + SwappedElementA const** ptr_A; + SwappedStrideA ptr_dA; + SwappedElementB const** ptr_B; + SwappedStrideB ptr_dB; + NonVoidElementScale const** ptr_S; + NonVoidStrideScale const* dS; + NonVoidElementZero const** ptr_Z; + int64_t scale_k; + int chunk_size; + int reload_factor = (chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + + if constexpr (SwapAB) { + init_M = get<1>(init_shape); + init_N = get<0>(init_shape); + } + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t mock_L = 1; + SwappedElementA const* ptr_A_first_batch; + SwappedElementB const* ptr_B_first_batch; + SwappedStrideA ptr_dA; + SwappedStrideB ptr_dB; + InternalSwappedStrideA dA; + InternalSwappedStrideB dB; + + if constexpr (not SwapAB) { + ptr_A_first_batch = reinterpret_cast(args.ptr_A); + ptr_B_first_batch = reinterpret_cast(args.ptr_B); + } + else { + ptr_A_first_batch = reinterpret_cast(args.ptr_B); + ptr_B_first_batch = reinterpret_cast(args.ptr_A); + } + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + if constexpr (not SwapAB) { + ptr_dA = args.dA; + ptr_dB = args.dB; + } + else { + ptr_dA = args.dB; + ptr_dB = args.dA; + } + dA = InternalSwappedStrideA{}; + if constexpr (is_layout::value) { + dA = make_layout( + transform_leaf(dA.shape(), [](auto x){ + if constexpr (not is_static_v) { + return static_cast(1); + } else { + return x; + } + }), + dA.stride()); + } + dB = InternalSwappedStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + if constexpr (not SwapAB) { + dA = args.dA; + dB = args.dB; + } + else { + dA = args.dB; + dB = args.dA; + } + ptr_dA = SwappedStrideA{}; + ptr_dB = SwappedStrideB{}; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, detail::get_gmem_layout(make_shape(init_M,init_K,mock_L), dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, detail::get_gmem_layout(make_shape(init_N,init_K,mock_L), dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + typename Params::TMA_Scale tma_load_scale{}; + typename Params::TMA_Zero tma_load_zero{}; + + void* tensormaps = workspace; + auto args_setup = [&](auto ptr_A, auto ptr_B, int64_t scale_k = 0, int chunk_size = 0, int reload_factor = 1) -> Params { + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tma_load_scale, + tma_load_zero, + tensormaps, + reinterpret_cast(ptr_A), + ptr_dA, + reinterpret_cast(ptr_B), + ptr_dB, + reinterpret_cast(args.ptr_S), + args.dS, + reinterpret_cast(args.ptr_Z), + scale_k, + chunk_size, + reload_factor, + dA, + dB + }; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A) + : args_setup(args.ptr_A, args.ptr_B); + } + else if constexpr (ModeHasScales) { + // NOTE: fix chunk wise scaling + //auto scale_k = (K + args.chunk_size - 1) / args.chunk_size; + auto scale_k = 1; + ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); + StrideScale dS{}; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = reinterpret_cast(args.ptr_Z); + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_layout(make_shape(init_M,scale_k,mock_L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return SwapAB ? args_setup(args.ptr_B, args.ptr_A, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})) + : args_setup(args.ptr_A, args.ptr_B, scale_k, args.chunk_size, (args.chunk_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{})); + + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + // Calculating workspace size + auto calculate_workspace_size = [SizeOfCuTensorMap, sm_count](uint32_t num_input_tensors) { + return num_input_tensors * SizeOfCuTensorMap * sm_count; + }; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return calculate_workspace_size(2); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale tensormap copies + return calculate_workspace_size(3); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies, followed by scale and zeros tensormap copies + return calculate_workspace_size(4); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in get_workspace_size."); + } + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto get_stride = [](auto stride) { + if constexpr (cute::is_pointer_v>) { + return *stride; + } + else { + return stride; + } + }; + auto dA = get_stride(args.dA); + auto dB = get_stride(args.dB); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(M,K,L), dA)); + implementable = implementable && cutlass::detail::check_alignment(detail::get_gmem_layout(cute::make_shape(N,K,L), dB)); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? N : M; + const int scale_k = (K + args.chunk_size - 1) / args.chunk_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.chunk_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = Utils::compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = Utils::compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytesExtra = Utils::compute_tma_transaction_bytes_extra(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK + TmaTransactionBytesExtra; + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(M,K,mock_L), mainloop_params.dA))); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(shape(detail::get_gmem_layout(make_shape(N,K,mock_L), mainloop_params.dB))); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } + else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class... Ts, + class... TMs, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + static_assert(sizeof... (TMs) == 2, "Direct convert needs two tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + static_assert(sizeof... (TMs) == 3, "Scaled convert needs three tensormaps"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + static_assert(sizeof... (TMs) == 4, "Scaled and zero convert needs four tensormaps"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + auto extra_input_partitions = Utils::partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + } + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with chunk_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int scale_load_k = *k_tile_iter / mainloop_params.reload_factor; // This will always be 0 when chunk_size == K. + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_scale.with(get<2>(input_tensormaps), *tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + if (cute::elect_one_sync()) { + copy(mainloop_params.tma_load_zero.with(get<3>(input_tensormaps), *tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SwappedSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(SwappedSmemLayoutAtomB{}) == 2, "SwappedSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = [&]{ + if constexpr (not is_layout::value) { + // Make register tensor with MMA layout + return make_fragment_like(tCrA_mma); + } + else { + // Make register tensor matching smem layout, converter will take care of de-swizzling + return make_tensor_like(tCsA(_,_,_,Int<0>{})); + } + }(); + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(SwappedSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = Utils::partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = Utils::retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + constexpr int K_WAIT_MAX = cute::min(K_BLOCK_MAX - 1, 7); + static_assert(K_BLOCK_MAX >= 4, "Consider increasing TileShapeK"); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + if (K_BLOCK_MAX > 1) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, read_stage); + } + + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + + --k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + + // NOTE: Check this when applying swizzling PR on top of GGMD + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + + warpgroup_wait(); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + + // NOTE: Check this when applying swizzling PR on top of GGMD + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); + } + else { + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) { + Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 2, read_stage); + } + if (k_block < K_BLOCK_MAX - 1) { + Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, k_block + 1); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + cute::TmaDescriptor* tma_desc_scale = &gmem_tensormap[sm_idx + 2*sm_count]; + cute::TmaDescriptor* tma_desc_zero = &gmem_tensormap[sm_idx + 3*sm_count]; + + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor pS_tensormap = make_tensor(mainloop_params.tma_load_scale.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sS_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_scale), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pS_tensormap), recast(sS_tensormap)); + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor pZ_tensormap = make_tensor(mainloop_params.tma_load_zero.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sZ_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_zero), Int<1>{}, Int<1>{}); + if (cute::elect_one_sync()) { + copy(recast(pZ_tensormap), recast(sZ_tensormap)); + } + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + + __syncwarp(); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::make_tuple(tma_desc_a, tma_desc_b, tma_desc_scale, tma_desc_zero); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_init."); + } + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + mainloop_params.ptr_S[next_batch]); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + mainloop_params.ptr_Z[next_batch]); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_address."); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + cute::array prob_shape_scale = {1,1,1,1,1}; + cute::array prob_stride_scale = {0,0,0,0,0}; + cute::array prob_shape_zero = {1,1,1,1,1}; + cute::array prob_stride_zero = {0,0,0,0,0}; + + SwappedElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, detail::get_gmem_layout(make_shape(M,K,Int<1>{}), mainloop_params.ptr_dA[next_group])); + + SwappedElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, detail::get_gmem_layout(make_shape(N,K,Int<1>{}), mainloop_params.ptr_dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + NonVoidElementScale const* ptr_S = nullptr; + // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; + auto scale_k = 1; + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, + prob_shape_scale, prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + ElementZero const* ptr_Z = nullptr; + // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; + auto scale_k = 1; + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, + prob_shape_zero, prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_scale) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_zero) { + stride = (stride * sizeof_bits_v) / 8; + } + + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_scale, + prob_shape_scale, + prob_stride_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_zero, + prob_shape_zero, + prob_stride_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_replace_global_tensor_properties."); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + tma_descriptor_cp_fence_release(get<2>(input_tensormaps), shared_tensormaps.smem_tensormap_scale); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + tma_descriptor_cp_fence_release(get<3>(input_tensormaps), shared_tensormaps.smem_tensormap_zero); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_cp_fence_release."); + } + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::tma_descriptor_fence_acquire(get<2>(input_tensormaps)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::tma_descriptor_fence_acquire(get<3>(input_tensormaps)); + } + else if constexpr (KernelConversionMode != ConversionMode::DirectConvert){ + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_fence_acquire."); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 628750fc..5264aa4c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -374,7 +374,7 @@ struct CollectiveMma< // Prepare the TMA loads for A and B // - constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; Tensor gA_mkl = get<0>(load_inputs); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 73564d3c..5c6c2a0f 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -85,13 +85,40 @@ class GemmUniversalAdapter; ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. +// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. +template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } + else { + return 0; + } +} + +} // namespace detail + template class GemmUniversalAdapter< GemmKernel_, - cute::enable_if_t::value>> + cute::enable_if_t>::value>> { public: - using GemmKernel = GemmKernel_; + using GemmKernel = GetUnderlyingKernel_t; using TileShape = typename GemmKernel::TileShape; using ElementA = typename GemmKernel::ElementA; using ElementB = typename GemmKernel::ElementB; @@ -158,7 +185,7 @@ public: CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; - static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + static int constexpr kStages = detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); // Inspect TiledCopy for A and B to compute the alignment size static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< @@ -336,7 +363,7 @@ public: } /// Primary run() entry point API that is static allowing users to create and manage their own params. - /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() + /// Supplied params struct must be construct by calling GemmKernel::to_underlying_arguments() static Status run(Params& params, cudaStream_t stream = nullptr, @@ -358,10 +385,10 @@ public: [[maybe_unused]] constexpr bool is_static_1x1x1 = cute::is_static_v and cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; - dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), - cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - void* kernel_params[] = {¶ms}; + [[maybe_unused]] dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + [[maybe_unused]] void* kernel_params[] = {¶ms}; if constexpr (kEnableCudaHostAdapter) { // @@ -377,13 +404,23 @@ public: #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with CUDA host adapter"); #endif - launch_result = cuda_adapter->launch(grid, - cluster, - block, - smem_size, - stream, - kernel_params, - 0); + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } } else { CUTLASS_TRACE_HOST("GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA host adapter is null"); @@ -392,8 +429,10 @@ public: } else { CUTLASS_ASSERT(cuda_adapter == nullptr); - void const* kernel = (void const*) device_kernel; - if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { + [[maybe_unused]] void const* kernel = (void const*) device_kernel; + static constexpr bool kClusterLaunch = GemmKernel::ArchTag::kMinComputeCapability == 90 + ; + if constexpr (kClusterLaunch) { if constexpr (is_static_1x1x1) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("GemmUniversal::run: Launching static 1x1x1 kernel"); @@ -526,11 +565,11 @@ public: template class GemmUniversalAdapter< GemmKernel_, - cute::enable_if_t::value>> + cute::enable_if_t>::value>> { public: - using GemmKernel = GemmKernel_; + using GemmKernel = GetUnderlyingKernel_t; static bool const kInternalTranspose = !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index fa275bda..236a1227 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -105,7 +105,8 @@ struct KernelCpAsyncWarpSpecializedPingpong { }; struct KernelCpAsyncWarpSpecializedCooperative { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; -struct KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedPingpong { +}; struct KernelTmaWarpSpecializedCooperative { }; @@ -247,6 +248,7 @@ struct MainloopSm90TmaGmmaRmemAWarpSpecialized { "KernelSchedule must be one of the warp specialized policies"); }; + template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, @@ -310,6 +312,7 @@ struct MainloopSm90TmaGmmaWarpSpecializedSparse { using Schedule = KernelSchedule; }; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp index 4a90a1d0..fbc0fdd7 100644 --- a/include/cutlass/gemm/group_array_problem_shape.hpp +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -69,7 +69,7 @@ struct GroupProblemShape { CUTLASS_HOST_DEVICE UnderlyingProblemShape const get_host_problem_shape(int32_t group_idx) const { - return host_problem_shapes[group_idx]; + return host_problem_shapes != nullptr ? host_problem_shapes[group_idx] : UnderlyingProblemShape{}; } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h b/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h new file mode 100644 index 00000000..3b7b126a --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/kernel/gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA_, ElementB_, ElementC_, + ElementAccumulator>::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void + > +struct DefaultGemmGroupedPerGroupScale; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout +> +struct DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + GroupScheduleMode_, + Operator, + SharedMemoryClear, + PermuteDLayout, + typename platform::enable_if< ! cutlass::is_complex::value>::type +> { + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments< + ElementA, + LayoutA, + ComplexTransform::kNone, + kAlignmentA, + ElementB, + LayoutB, + ComplexTransform::kNone, + kAlignmentB, + LayoutC, + kInternalTranspose + >; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementC, + typename MapArguments::LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + true, + Operator, + SharedMemoryClear, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + PermuteDLayout + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmGroupedPerGroupScale< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Complex-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear + > +struct DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + GroupScheduleMode_, + Operator, + SharedMemoryClear, + layout::NoPermute, /*PermuteDLayout*/ + typename platform::enable_if::value>::type +> { + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments< + ElementA, + LayoutA, + TransformA, + kAlignmentA, + ElementB, + LayoutB, + TransformB, + kAlignmentB, + LayoutC, + kInternalTranspose + >; + + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + ElementC, + typename MapArguments::LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MapArguments::kTransformA, + MapArguments::kTransformB, + Operator, + false + >::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::GemmGroupedPerGroupScale< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, + ThreadblockSwizzle, + GroupScheduleMode_, + kInternalTranspose + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/ell_gemm.h b/include/cutlass/gemm/kernel/ell_gemm.h index 7cd61980..aad32959 100644 --- a/include/cutlass/gemm/kernel/ell_gemm.h +++ b/include/cutlass/gemm/kernel/ell_gemm.h @@ -691,7 +691,7 @@ struct EllGemm { static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - constexpr bool is_double = (sizeof(Mma::IteratorA::Element) == 8); + constexpr bool is_double = (sizeof(typename Mma::IteratorA::Element) == 8); constexpr bool is_multiple_alignment = (kAlignmentA > 1) && (kAlignmentB > 1) && (kAlignmentC > 1); const bool is_specialized_blocksize = @@ -699,11 +699,11 @@ struct EllGemm { && params.ell_blocksize >= Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add if ((is_double || is_multiple_alignment) && is_specialized_blocksize) { - mma.operator()( + mma.template operator()( gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); } else { - mma.operator()( + mma.template operator()( gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, ell_iterator); } } diff --git a/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h b/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h new file mode 100644 index 00000000..972681ab --- /dev/null +++ b/include/cutlass/gemm/kernel/gemm_grouped_per_group_scale.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Problem visitor for grouped GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform + bool Transposed = false +> +struct GemmGroupedPerGroupScale : + public GemmGrouped { + + // Inherit constructors + using Base = GemmGrouped; + + // Inherit type definitions + using typename Base::Mma; + using typename Base::Epilogue; + using typename Base::EpilogueOutputOp; + using typename Base::ThreadblockSwizzle; + using typename Base::Params; + using typename Base::SharedStorage; + + // Explicitly inherit the kTransposed constant + static bool const kTransposed = Base::kTransposed; + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. + // + typename Base::ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, + 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA *ptr_A = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB *ptr_B = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{ + 0, + threadblock_offset.n() + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), + ptr_A, + {problem_size.m(), problem_size.k()}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), + ptr_B, + {problem_size.k(), problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + ElementC *ptr_C = params.ptr_C[problem_idx]; + ElementC *ptr_D = params.ptr_D[problem_idx]; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size.mn(), + thread_idx, + threadblock_offset.mn() + ); + + Epilogue epilogue( + shared_storage.kernel.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // The if branch is for the per-group scaling epilogue. The customized epilogue operator scales each gemm output by a scalar value. + // This branch is only enabled if EpilogueOutputOp is LinearCombination. + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + } else { + EpilogueOutputOp output_op(params.output_op); + // Execute the epilogue operator to update the destination tensor. + epilogue( + output_op, + iterator_D, + accumulators, + iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h index 304f23e7..1c4411bd 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/gemm_grouped_problem_visitor.h @@ -68,7 +68,7 @@ struct GemmGroupedProblemSizeHelper { CUTLASS_HOST_DEVICE static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) { if (kTransposed) { - swap(problem.m(), problem.n()); + cutlass::swap(problem.m(), problem.n()); } } diff --git a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h index cdb82599..5d8ce789 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h @@ -437,7 +437,7 @@ protected: int m_begin = tile_work.tiled_coord.m() * Mma::Shape::kM; int m_end = params.block_mapping.problem_size.m(); - return Mma::IteratorA( + return typename Mma::IteratorA( params.params_A, ptr_A, { m_end, tile_work.k_end }, @@ -466,7 +466,7 @@ protected: int n_begin = tile_work.tiled_coord.n() * Mma::Shape::kN; int n_end = params.block_mapping.problem_size.n(); - return Mma::IteratorB( + return typename Mma::IteratorB( params.params_B, ptr_B, { tile_work.k_end, n_end }, diff --git a/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/include/cutlass/gemm/kernel/grouped_problem_visitor.h index 31787372..4df76ec0 100644 --- a/include/cutlass/gemm/kernel/grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -66,10 +66,10 @@ struct BaseGroupedProblemVisitor { int32_t problem_idx; int32_t problem_start; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE ProblemInfo() : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {} - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE ProblemInfo(int32_t problem_idx_, int32_t problem_start_) : problem_idx(problem_idx_), problem_start(problem_start_) {} }; diff --git a/include/cutlass/gemm/kernel/params_universal_base.h b/include/cutlass/gemm/kernel/params_universal_base.h index 86986f2e..172855ed 100644 --- a/include/cutlass/gemm/kernel/params_universal_base.h +++ b/include/cutlass/gemm/kernel/params_universal_base.h @@ -182,7 +182,7 @@ struct UniversalParamsBase CUTLASS_TRACE_HOST(" Initialize " << workspace_bytes << " workspace bytes"); cudaError_t result = cudaMemsetAsync( - semaphore, + static_cast(workspace), 0, workspace_bytes, stream); diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 6b36db21..e8383faf 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -479,14 +479,14 @@ public: // Construct iterators to A and B operands for Mma1 typename Mma1::IteratorA iterator_A( - Mma1::IteratorA::Params(ldm_A), + typename Mma1::IteratorA::Params(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_MxK); typename Mma1::IteratorB iterator_BT( - Mma1::IteratorB::Params(ldm_B), + typename Mma1::IteratorB::Params(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, @@ -494,14 +494,14 @@ public: // Construct iterators to A and B operands for Mma2 typename Mma2::IteratorA iterator_B( - Mma2::IteratorA::Params(ldm_B), + typename Mma2::IteratorA::Params(ldm_B), ptr_B, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_MxK); typename Mma2::IteratorB iterator_AT( - Mma2::IteratorB::Params(ldm_A), + typename Mma2::IteratorB::Params(ldm_A), ptr_A, {problem_size_k, problem_size.n()}, thread_idx, @@ -560,7 +560,7 @@ public: // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), ptr_C, problem_size.mn(), thread_idx, @@ -570,7 +570,7 @@ public: // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), ptr_D, problem_size.mn(), thread_idx, @@ -634,7 +634,7 @@ public: // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldc[problem_idx]), ptr_C, problem_size.mn(), thread_idx, @@ -644,7 +644,7 @@ public: // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), + typename Epilogue::OutputTileIterator::Params(params.ldd[problem_idx]), ptr_D, problem_size.mn(), thread_idx, diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h index 2e31c778..054d2a73 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h @@ -357,7 +357,7 @@ struct Rank2KGroupedProblemVisitor : public GroupedProblemVisitor< int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); if (kFillModeC == cutlass::FillMode::kUpper) { - swap(macro_row, macro_col); + cutlass::swap(macro_row, macro_col); } int32_t row = OffsetHelper::macro_row_to_row(macro_row, threadblock_id); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 823e919e..c0c10b97 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -218,11 +218,6 @@ public: uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -231,6 +226,11 @@ public: workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + TileSchedulerParams scheduler; if constexpr (IsGroupedGemmKernel) { scheduler = TileScheduler::to_underlying_arguments( @@ -276,10 +276,6 @@ public: size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -294,6 +290,10 @@ public: workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -306,23 +306,25 @@ public: constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } @@ -633,7 +635,7 @@ public: constexpr bool IsEpiLoad = true; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -644,7 +646,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } load_order_barrier.wait(); @@ -667,7 +669,7 @@ public: auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; @@ -697,7 +699,7 @@ public: // tensormap update { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -708,7 +710,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } } @@ -738,7 +740,7 @@ public: if (work_tile_info.is_valid()) { if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -749,8 +751,8 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, - epi_store_tensormap, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + epi_store_tensormap, consumer_warp_group_idx); } } @@ -805,7 +807,7 @@ public: params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -843,7 +845,7 @@ public: problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -854,7 +856,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 38633764..1b7c0cb4 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -226,11 +226,6 @@ public: uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); @@ -239,6 +234,11 @@ public: workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, sm_count); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means // subtile will not be used, therefore separate reduction will not be enabled. @@ -288,10 +288,6 @@ public: size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; if (sm_count <= 0) { @@ -306,6 +302,10 @@ public: workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -318,6 +318,20 @@ public: constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + status = TileScheduler::template initialize_workspace( args.scheduler, workspace_ptr + workspace_offset, stream, typename ProblemShape::UnderlyingProblemShape{}, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); workspace_offset += TileScheduler::template get_workspace_size( @@ -326,19 +340,6 @@ public: if (status != Status::kSuccess) { return status; } - - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - if (status != Status::kSuccess) { - return status; - } - return status; } @@ -666,7 +667,7 @@ public: constexpr bool IsEpiLoad = true; if (work_tile_info.is_valid()) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -677,7 +678,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } load_order_barrier.wait(); @@ -700,7 +701,7 @@ public: auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_load_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_load_tensormap); } bool wait = work_tile_info.is_valid() && curr_batch != next_work_tile_info.L_idx; @@ -730,7 +731,7 @@ public: // tensormap update { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_load_tensormap, @@ -741,7 +742,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_load_tensormap, 0); } } @@ -771,7 +772,7 @@ public: if (work_tile_info.is_valid()) { if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -782,7 +783,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } @@ -844,7 +845,7 @@ public: params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); if (did_batch_change) { - collective_epilogue.tensormaps_fence_acquire(epi_store_tensormap); + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { @@ -897,7 +898,7 @@ public: problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), 1); } if (warp_idx_in_warp_group == 0) { - collective_epilogue.tensormaps_perform_update( + collective_epilogue.template tensormaps_perform_update( shared_storage.tensormaps.epilogue, params.epilogue, epi_store_tensormap, @@ -908,7 +909,7 @@ public: // Converge before issuing tensormap fence release since fence is aligned __syncwarp(); - collective_epilogue.tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, + collective_epilogue.template tensormaps_cp_fence_release(shared_storage.tensormaps.epilogue, epi_store_tensormap, consumer_warp_group_idx); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 243a9e70..0dece139 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -51,8 +51,6 @@ namespace cutlass::gemm::kernel { -/////////////////////////////////////////////////////////////////////////////// - template < class ProblemShape_, class CollectiveMainloop_, @@ -107,7 +105,6 @@ public: TileShape, ClusterShape >::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; @@ -122,7 +119,8 @@ public: static constexpr uint32_t NumMmaWarpGroups = NumMMAThreads / NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - + static constexpr uint32_t NumFixupBarriers = NumMmaWarpGroups; + /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; static constexpr uint32_t MmaRegisterRequirement = 232; @@ -207,22 +205,23 @@ public: uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* mainloop_workspace = nullptr; // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means // subtile will not be used, therefore separate reduction will not be enabled. constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ); return { args.mode, @@ -254,13 +253,12 @@ public: size_t workspace_size = 0; constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); return workspace_size; } @@ -273,17 +271,17 @@ public: constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -314,6 +312,7 @@ public: operator()(Params const& params, char* smem_buf) { using namespace cute; using X = Underscore; + #if defined(__CUDA_ARCH_FEAT_SM90_ALL) # define ENABLE_SM90_KERNEL_LEVEL 1 #endif @@ -487,7 +486,6 @@ public: // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); collective_mainloop.load( @@ -581,11 +579,10 @@ public: auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - // Allocate the accumulators for the (M,N) blk_shape // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { collective_mainloop.mma( mainloop_pipeline, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index cf4a552c..c19a8e9f 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -105,14 +105,24 @@ public: static_assert(!cute::is_same_v, "Ping-pong kernel does not currently support stream-K scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + TileSchedulerTag, + ArchTag, + TileShape, + ClusterShape + >::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; using TileSchedulerParams = typename TileScheduler::Params; - + // Warp specialization thread count per threadblock + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp for C static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t NumMMAThreads = size(TiledMma{}); // 4 warp + static constexpr uint32_t MaxThreadsPerBlock = NumMMAThreads * NumMmaWarpGroups + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static_assert(NumMMAThreads == 128, "Pingpong kernel must have TiledMMA operating using 128 threads."); + static_assert(MaxThreadsPerBlock == 384, "Pingpong kernel must have 384 threads in total."); /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; @@ -142,7 +152,7 @@ public: alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; - + struct TensorStorage : cute::aligned_struct<128, _1> { using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -208,16 +218,17 @@ public: uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - void* epilogue_workspace = workspace_ptr + workspace_offset; workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + void* mainloop_workspace = nullptr; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); return { args.mode, @@ -225,7 +236,9 @@ public: CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace) + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles + ) }; } @@ -247,13 +260,14 @@ public: static size_t get_workspace_size(Arguments const& args) { size_t workspace_size = 0; - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + return workspace_size; } @@ -266,17 +280,17 @@ public: static constexpr uint32_t NumEpilogueSubTiles = 1; static constexpr uint32_t NumAccumulatorMtxs = 1; - status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; } - status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles, NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -308,9 +322,12 @@ public: using namespace cute; using X = Underscore; +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); #else // Preconditions @@ -350,6 +367,7 @@ public: CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); } + // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; @@ -450,8 +468,8 @@ public: auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); TileScheduler scheduler{params.scheduler}; - if (warp_group_role == WarpGroupRole::Consumer1) { + // Advance 2nd Math WG to the next work tile for the startup scheduler.advance_to_next_work(); // Advance 2nd Math WG pipeline states to the end of 1st Math WG @@ -466,7 +484,7 @@ public: if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); - + // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { // Ensure that the prefetched kernel does not touch @@ -546,6 +564,7 @@ public: // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End } // Producer Warp Group End @@ -564,7 +583,7 @@ public: return; } #endif - + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 5e61e7c9..08437c70 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -29,8 +29,8 @@ * **************************************************************************************************/ #pragma once -#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/static_tile_scheduler.hpp" namespace cutlass::gemm::kernel::detail { diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index 888be276..a30d9ce0 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -337,12 +337,16 @@ public: uint64_t blk_per_grid_dim = divmod_cluster_shape_minor.divide(linear_idx - group_info.start_linear_idx); divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + // With static schedulers, we launch grid such that all cluster are linear (1-D) order, i.e., + // there can only be one cluster in the minor dimension. get_grid_shape() in scheduler params + // put cluster_shape.m/n() as the minor dimension based on raster order AlongN/M resp. + // Therefore, the offset of a CTA (inside a cluster) in the minor dimension can be directly be + // inferred by the blockIdx along the minor dimension. if (raster_order == RasterOrder::AlongN) { - cluster_minor_offset = cta_m_in_cluster; + cluster_minor_offset = blockIdx.x; } else { - cluster_minor_offset = cta_n_in_cluster; + cluster_minor_offset = blockIdx.y; } uint64_t cluster_idx_minor, cluster_idx_major; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index 80b374ad..b5e62164 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -58,7 +58,9 @@ private: using UnderlyingArguments = typename UnderlyingScheduler::Arguments; using UnderlyingParams = typename UnderlyingScheduler::Params; + dim3 block_id_in_cluster_; uint64_t current_work_linear_idx_ = 0; + uint32_t unit_iter_start_ = 0; public: @@ -240,25 +242,26 @@ public: CUTLASS_HOST_DEVICE PersistentTileSchedulerSm90StreamK() { }; - CUTLASS_HOST_DEVICE - PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_) { + CUTLASS_DEVICE + PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_), block_id_in_cluster_(cute::block_id_in_cluster()) { if (params_.raster_order_ == RasterOrder::AlongN) { current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); } else { current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); } + } CUTLASS_DEVICE WorkTileInfo - get_current_work() const { - return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params); + get_current_work() { + return get_current_work_for_linear_idx(unit_iter_start_, current_work_linear_idx_, block_id_in_cluster_, scheduler_params); } CUTLASS_DEVICE static WorkTileInfo - get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) { + get_current_work_for_linear_idx(uint32_t &unit_iter_start, uint64_t linear_idx, dim3 block_id_in_cluster, Params const& params) { // The maximum number of work units is units_per_problem_ * splits_. // The multiplication by splits_ is used for handling split-K, in which // units_per_problem_ is equal to the total number of output tiles. To account @@ -271,7 +274,7 @@ public: } WorkTileInfo work_tile_info; - assign_work(params, linear_idx, work_tile_info); + assign_work(params, linear_idx, block_id_in_cluster, work_tile_info, unit_iter_start); return work_tile_info; } @@ -283,13 +286,15 @@ public: bool continue_current_work(WorkTileInfo& work_tile_info) const { return continue_current_work_for_linear_idx( - current_work_linear_idx_, work_tile_info, scheduler_params); + current_work_linear_idx_, unit_iter_start_, block_id_in_cluster_, work_tile_info, scheduler_params); } CUTLASS_DEVICE static bool continue_current_work_for_linear_idx( uint64_t linear_idx, + uint32_t unit_iter_start, + dim3 block_id_in_cluster, WorkTileInfo& work_tile_info, Params const& params) { @@ -298,7 +303,7 @@ public: if (work_tile_info.k_tile_remaining == 0) { return false; } - assign_work(params, linear_idx, work_tile_info); + fast_assign_work(unit_iter_start, params, linear_idx, block_id_in_cluster, work_tile_info); return work_tile_info.is_valid(); } @@ -316,9 +321,11 @@ public: return false; } return not get_current_work_for_linear_idx( + unit_iter_start_, current_work_linear_idx_ + ( uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count) ), + block_id_in_cluster_, scheduler_params ).is_valid(); } @@ -420,22 +427,24 @@ public: uint64_t reduction_tile_idx = tile_idx; uint64_t num_peers = 0; uint64_t reduction_peer_offset = 0; - if (params.requires_separate_reduction()) { + if ( + params.requires_separate_reduction() + ) { // If separate reduction is to be performed, each stream-K unit writes its partials // to a separate portion of the workspace. There are as many of these portions as there // are peers for a given output tile, so we multiply the tile index by the maximum peer count. - auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, static_cast(work_tile_info.K_idx)); + auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, work_tile_info); + auto peer_id_in_output_tile = my_peer_id - first_peer_id; num_peers = last_peer_id - first_peer_id + 1; - reduction_tile_idx *= Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); - reduction_peer_offset = my_peer_id * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}); + reduction_tile_idx = tile_idx * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); + reduction_peer_offset = peer_id_in_output_tile * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; } // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. // Thus, the start of the reduction space is the same across all threads in a warp group. - uint64_t reduction_offset = - (static_cast(cute::size<0>(TileShape{})) * static_cast(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) + - reduction_peer_offset + + uint64_t reduction_offset_base = (static_cast(cute::size<0>(TileShape{})) * static_cast(cute::size<1>(TileShape{})) * reduction_tile_idx * num_accumulator_mtxs) + (static_cast(size(accumulators)) * barrier_idx * BarrierManager::ThreadCount); + uint64_t reduction_offset = reduction_offset_base + reduction_peer_offset; ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; @@ -457,7 +466,9 @@ public: if (params.divmod_splits_.divisor > 1) { reduction_tiles = params.units_per_problem_; } - else if (params.requires_separate_reduction()) { + else if ( + params.requires_separate_reduction() + ) { reduction_tiles = params.sk_tiles_ * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); } else { @@ -470,29 +481,17 @@ public: reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); if (work_tile_info.is_reduction_unit()) { - plus add_fragments; - uint64_t peer_offset = size(accumulators) * num_barriers * BarrierManager::ThreadCount; - // Wait until the peers collaborating on this output tile have all written // their accumulators to workspace. BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, num_peers); - // Load the first peer's data - BlockStripedReduceT::load(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); - - for (uint64_t i = 1; i < num_peers; ++i) { - // Load peer fragment - AccumulatorArrayT addend_fragment; - auto peer_reduction_workspace = reinterpret_cast(group_reduction_workspace + (i * peer_offset)); - - BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, barrier_group_thread_idx); - - // Add peer fragment - *accumulator_array = add_fragments(*accumulator_array, addend_fragment); - } + separate_reduction(accumulators, num_barriers, group_reduction_workspace, barrier_group_thread_idx, num_peers, num_accumulator_mtxs); } else if (!compute_epilogue(work_tile_info, params)) { - if (params.requires_separate_reduction() || work_tile_info.K_idx == 0) { + if ( + params.requires_separate_reduction() + || work_tile_info.K_idx == 0 + ) { // The first peer initializes the workspace partials in the non-separate-reduction case, // and all peers write to their own location in workspace when using separate reduction BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); @@ -513,12 +512,16 @@ public: BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment); } else { - if (params.reduction_mode_ == ReductionMode::Deterministic) { + if ( + params.reduction_mode_ == ReductionMode::Deterministic + ) { + // Wait until the preceding split added its accumulators BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + } else { - // Wait unitl the first split has stored its accumulators + // Wait until the first split has stored its accumulators BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); } @@ -528,6 +531,36 @@ public: } } + template + CUTLASS_DEVICE + static void + separate_reduction( + FrgTensorC& accumulators, + uint32_t num_barriers, + typename FrgTensorC::value_type* reduction_workspace, + uint32_t thread_idx, + uint64_t num_peers, + uint32_t num_accumulator_mtxs) { + using AccumulatorArrayT = Array; + using BlockStripedReduceT = BlockStripedReduce; + + AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); + + plus add_fragments; + uint64_t peer_offset = cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * num_accumulator_mtxs; + + for (uint64_t i = 0; i < num_peers; ++i) { + // Load peer fragment + AccumulatorArrayT addend_fragment; + auto peer_reduction_workspace = reinterpret_cast(reduction_workspace + (i * peer_offset)); + + BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, thread_idx); + + // Add peer fragment + *accumulator_array = add_fragments(*accumulator_array, addend_fragment); + } + } + // Returns whether the block assigned this work should compute the epilogue for the corresponding // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. CUTLASS_HOST_DEVICE @@ -587,6 +620,7 @@ public: args.max_swizzle_size, args.raster_order, args.decomposition_mode, + args.reduction_mode, mma_warp_groups, sizeof_bits::value, sizeof_bits::value, @@ -627,6 +661,7 @@ public: args.max_swizzle_size, args.raster_order, args.decomposition_mode, + args.reduction_mode, mma_warp_groups, sizeof_bits::value, sizeof_bits::value, @@ -668,224 +703,235 @@ public: return get_current_work(); } -private: - // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info - // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining - // iterations) is used to find the next tile in the current work unit. + // Given raster order and current work tile linear index, reset cta m and n index in the cluster. CUTLASS_DEVICE - static void - assign_work( + static dim3 + get_current_work_cta_m_n_in_cluster( + Params const& params, + uint64_t linear_idx, + dim3 block_id_in_cluster) { + auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = block_id_in_cluster; + uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); + uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); + return {static_cast(cta_m_in_cluster), static_cast(cta_n_in_cluster), _}; + } + +private: + + CUTLASS_DEVICE + static uint32_t + get_current_work_iter_start_possible_update_work_tile_k_remaining( Params const& params, uint64_t linear_idx, WorkTileInfo& work_tile_info) { + // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K + // threadblock individually. For the most part, the set of K iterations corresponding to stream-K + // work was divided amongst stream-K threadblocks, and a threadblock determined which tile + // it would compute a (potentially-partial) output tile for based on the space of k iterations + // assigned to it. This often results in stream-K threadblocks processing tiles with different + // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the + // (generally few) waves of threadblocks assigned to compute stream-K work. + // + // With the introduction of threadblock clusters, there is additional benefit to maintaining + // locality in the K dimension: shared portions of operands can be multicasted to threadblocks + // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to + // threadblocks respects the ability to perform multicasting. + // + // To do so, we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); - auto [cta_m_in_cluster_, cta_n_in_cluster_, _] = cute::block_id_in_cluster(); - uint64_t cta_m_in_cluster = static_cast(cta_m_in_cluster_); - uint64_t cta_n_in_cluster = static_cast(cta_n_in_cluster_); - uint64_t output_tile_id = linear_idx; - if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { - // Separate-reduction work - auto cluster_size = params.get_cluster_size(); - // Divide up the linearized separate reduction units into clusters - uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); - uint64_t cluster_tile_idx, epi_subtile_idx; - params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id = cluster_tile_idx * cluster_size; + uint64_t group_idx; + params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); - work_tile_info.setup_separate_reduction(epi_subtile_idx); + // Determine whether we are in a "big group" that will process an additional + // stream-K cluster tile. + uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); + uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); + if (group_idx < params.big_groups_) { + ++sk_cluster_tiles_in_group; } - else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { - // Data-parallel work - output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; - work_tile_info.K_idx = 0; - work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; - work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; + + // Determine whether we are in a "big unit" within the group, that will process + // an additional K chunk in the group. + uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); + uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); + uint64_t big_units_in_group = params.div_cluster_size( + k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); + + uint64_t split; + params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); + + bool is_split_k = params.divmod_splits_.divisor > 1; + uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; + uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; + uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; + uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; + + // Determine the starting k iteration computed by this stream-K work unit + uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + + (k_tiles_per_split * split); + + // Adjust the starting position and number of k iterations for "big units," which + // compute one extra iteration. If there are any big units, they will be the first + // in the linearized ID space. + auto k_tiles_in_my_split = k_tiles_per_split; + if (big_unit_cmp_lhs < big_unit_cmp_rhs) { + // Since the "big units" are the first units in the linearized ID space, each + // of the units preceding this big unit computed one extra iteration. Thus, + // we must offset our start iteration by the number of units that precede + // the current unit in the linearized ID space. + unit_iter_start += big_unit_cmp_lhs; + ++k_tiles_in_my_split; } else { - // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K - // threadblock individually. For the most part, the set of K iterations corresponding to stream-K - // work was divided amongst stream-K threadblocks, and a threadblock determined which tile - // it would compute a (potentially-partial) output tile for based on the space of k iterations - // assigned to it. This often results in stream-K threadblocks processing tiles with different - // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the - // (generally few) waves of threadblocks assigned to compute stream-K work. - // - // With the introduction of threadblock clusters, there is additional benefit to maintaining - // locality in the K dimension: shared portions of operands can be multicasted to threadblocks - // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to - // threadblocks respects the ability to perform multicasting. - // - // To do so, we divide up the linearized stream-K units into clusters and share the same K - // offsets for work within clusters. - - uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); - - uint64_t group_idx; - params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); - - // Determine whether we are in a "big group" that will process an additional - // stream-K cluster tile. - uint64_t sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); - uint64_t sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); - if (group_idx < params.big_groups_) { - ++sk_cluster_tiles_in_group; + // Increment by one for each of the big clusters (since all big units precede this unit) + unit_iter_start += big_unit_cmp_rhs; + } + if (!is_split_k) { + // Adjust the unit starting position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + int unused, start_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); + if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another + // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take over these K tiles. + unit_iter_start -= start_tile_k_tile; + k_tiles_in_my_split += start_tile_k_tile; } - - // Determine whether we are in a "big unit" within the group, that will process - // an additional K chunk in the group. - uint64_t sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); - uint64_t k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint64_t k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); - uint64_t big_units_in_group = params.div_cluster_size( - k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); - - uint64_t split; - params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); - - bool is_split_k = params.divmod_splits_.divisor > 1; - uint64_t big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; - uint64_t big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; - uint64_t linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; - uint64_t k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; - - // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + - (k_tiles_per_split * split); - - // Adjust the starting position and number of k iterations for "big units," which - // compute one extra iteration. If there are any big units, they will be the first - // in the linearized ID space. - auto k_tiles_in_my_split = k_tiles_per_split; - if (big_unit_cmp_lhs < big_unit_cmp_rhs) { - // Since the "big units" are the first units in the linearized ID space, each - // of the units preceding this big unit computed one extra iteration. Thus, - // we must offset our start iteration by the number of units that precede - // the current unit in the linearized ID space. - unit_iter_start += big_unit_cmp_lhs; - ++k_tiles_in_my_split; + else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); + unit_iter_start += adjustment_tiles; + k_tiles_in_my_split -= adjustment_tiles; } - else { - // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += big_unit_cmp_rhs; + else if (params.ktile_start_alignment_count_ == 2 && start_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // if end on odd number ktile within the output tile + // now end at ktile that one before my ktile end (give one ktile to next sm) + unit_iter_start -= 1; + k_tiles_in_my_split += 1; } + } + if (work_tile_info.k_tile_count == 0) { + // This is a new unit if (!is_split_k) { - // Adjust the unit starting position and number of tiles to avoid + // + // Adjust the unit ending position and number of tiles to avoid // computing splits of size less than min_iters_per_sk_unit_ - int unused, start_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); - if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another - // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take over these K tiles. - unit_iter_start -= start_tile_k_tile; - k_tiles_in_my_split += start_tile_k_tile; - } - else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // + + // Begin by assuming that no adjustment is needed + auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; + + int unused, end_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); + + if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); - unit_iter_start += adjustment_tiles; - k_tiles_in_my_split -= adjustment_tiles; + k_tiles_in_my_split -= end_tile_k_tile; } - else if (params.ktile_start_alignment_count == 2 && start_tile_k_tile % 2 != 0) { + else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take on these K tiles. + k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); + } + else if (params.ktile_start_alignment_count_ == 2 && end_tile_k_tile % 2 != 0) { // ktile for each SM start from even number // If start from odd number ktile within the output tile // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // if end on odd number ktile within the output tile + // If end on odd number ktile within the output tile, // now end at ktile that one before my ktile end (give one ktile to next sm) - unit_iter_start -= 1; - k_tiles_in_my_split += 1; + k_tiles_in_my_split -= 1; } } - if (work_tile_info.k_tile_count == 0) { - // This is a new unit - - if (!is_split_k) { - // - // Adjust the unit ending position and number of tiles to avoid - // computing splits of size less than min_iters_per_sk_unit_ - // - - // Begin by assuming that no adjustment is needed - auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; - - int unused, end_tile_k_tile; - params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); - - if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { - // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. - k_tiles_in_my_split -= end_tile_k_tile; - } - else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { - // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, - // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. - // Adjust our work to take on these K tiles. - k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); - } - else if (params.ktile_start_alignment_count == 2 && end_tile_k_tile % 2 != 0) { - // ktile for each SM start from even number - // If start from odd number ktile within the output tile - // now start at the ktile one before my initial ktile start (take one ktile from prev sm) - // If end on odd number ktile within the output tile, - // now end at ktile that one before my ktile end (give one ktile to next sm) - k_tiles_in_my_split -= 1; - } - } - - work_tile_info.k_tile_remaining = k_tiles_in_my_split; - } - - uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; - - // Find the output tile corresponding to the final k tile covered by this - // work unit. Stream-K work units will work backwards in terms of the tiles they - // are responsible computing. This is beneficial because the final (partial) - // tile computed by a stream-K block is typically the beginning of the output - // tile, while the beginning (partial) tile is typically the ending of another - // output tile. Since ending portions of an output tile must reduce across - // other work units computing portions of that output tile, it is preferable - // for them to be computed later, so as to reduce the likelihood of blocking - // on other work. - - auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); - uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; - uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; - - // Convert the output tile from the linearized space within each group to the - // overall linearized space. - output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; - - // Bring the linearized tile ID back into the space of tiles, rather than clusters - output_tile_id *= params.get_cluster_size(); - - // The final linearized tile ID is in units of the cluster dimension over which we rasterize. - if (params.raster_order_ == RasterOrder::AlongN) { - output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - else { - output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - - // The unit's starting k iteration in the current tile is either the starting - // iteration for the tile as a whole, or the starting k iteration for the unit - // as a whole (if the latter is greater than the former). - uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); - - // Similarly, the unit's ending k iteration (exclusive) is either the end of - // the current tile it is assigned, or the ending iteration of the unit as a whole - // (if the latter is less than the former). - uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); - - // Set the k offset to be the starting k tile for this output tile - work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); - work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; + work_tile_info.k_tile_remaining = k_tiles_in_my_split; } + return unit_iter_start; + } + + // Update output tile index given existing remaining k tiles of current work tile. + CUTLASS_DEVICE + static uint64_t update_output_tile_id_and_work_tile_k( + Params const& params, + WorkTileInfo& work_tile_info, + uint64_t linear_idx, + uint32_t unit_iter_start, + uint64_t cta_m_in_cluster, + uint64_t cta_n_in_cluster) { + // we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + uint64_t cluster_linear_work_idx = params.div_cluster_size(linear_idx); + + uint64_t unused, group_idx; + params.divmod_sk_groups_(unused, group_idx, cluster_linear_work_idx); + + uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; + + // Find the output tile corresponding to the final k tile covered by this + // work unit. Stream-K work units will work backwards in terms of the tiles they + // are responsible computing. This is beneficial because the final (partial) + // tile computed by a stream-K block is typically the beginning of the output + // tile, while the beginning (partial) tile is typically the ending of another + // output tile. Since ending portions of an output tile must reduce across + // other work units computing portions of that output tile, it is preferable + // for them to be computed later, so as to reduce the likelihood of blocking + // on other work. + + auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); + uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + + // Convert the output tile from the linearized space within each group to the + // overall linearized space. + uint64_t output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; + + // Bring the linearized tile ID back into the space of tiles, rather than clusters + output_tile_id *= params.get_cluster_size(); + + // The final linearized tile ID is in units of the cluster dimension over which we rasterize. + if (params.raster_order_ == RasterOrder::AlongN) { + output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } + else { + output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } + // The unit's starting k iteration in the current tile is either the starting + // iteration for the tile as a whole, or the starting k iteration for the unit + // as a whole (if the latter is greater than the former). + uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); + + // Similarly, the unit's ending k iteration (exclusive) is either the end of + // the current tile it is assigned, or the ending iteration of the unit as a whole + // (if the latter is less than the former). + uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); + + // Set the k offset to be the starting k tile for this output tile + work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); + work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; + + return output_tile_id; + } + // Given output tile index, update M, N, L index of current work tile info. + CUTLASS_DEVICE + static void + update_work_tile_m_n_l( + Params const& params, + uint32_t output_tile_id, + WorkTileInfo& work_tile_info, + uint64_t cta_m_in_cluster, + uint64_t cta_n_in_cluster) { uint64_t work_idx_l, remainder; params.divmod_batch_(work_idx_l, remainder, output_tile_id); @@ -907,18 +953,81 @@ private: work_tile_info.L_idx = static_cast(work_idx_l); } + // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info + // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining + // iterations) is used to find the next tile in the current work unit. + CUTLASS_DEVICE + static void + assign_work( + Params const& params, + uint64_t linear_idx, + dim3 block_id_in_cluster, + WorkTileInfo& work_tile_info, + uint32_t &unit_iter_start) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = + get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); + + uint64_t output_tile_id = linear_idx; + if (linear_idx >= params.units_per_problem_ * params.divmod_splits_.divisor) { + // Separate-reduction work + auto cluster_size = params.get_cluster_size(); + // Divide up the linearized separate reduction units into clusters + uint64_t cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); + uint64_t cluster_tile_idx, epi_subtile_idx; + params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); + // Bring the linearized tile ID back into the space of tiles, rather than clusters + output_tile_id = cluster_tile_idx * cluster_size; + + work_tile_info.setup_separate_reduction(epi_subtile_idx); + } + else if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { + // Data-parallel work + output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; + work_tile_info.K_idx = 0; + work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; + work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; + } + else { + unit_iter_start = get_current_work_iter_start_possible_update_work_tile_k_remaining(params, linear_idx, work_tile_info); + output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, + linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); + } + update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); + } + + // The fast path to get current output tile index then update fields of work tile info + // when continuing current work tile is needed, since k tile starting index has precomputed + // in the first time fetching current work tile. + CUTLASS_DEVICE + static void + fast_assign_work( + uint32_t unit_iter_start, + Params const& params, + uint64_t linear_idx, + dim3 block_id_in_cluster, + WorkTileInfo& work_tile_info) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = + get_current_work_cta_m_n_in_cluster(params, linear_idx, block_id_in_cluster); + + uint64_t output_tile_id = update_output_tile_id_and_work_tile_k(params, work_tile_info, + linear_idx, unit_iter_start, cta_m_in_cluster, cta_n_in_cluster); + + update_work_tile_m_n_l(params, output_tile_id, work_tile_info, cta_m_in_cluster, cta_n_in_cluster); + } + // Returns the starting and ending peer ID of this tile CUTLASS_HOST_DEVICE static auto - tile_peer_range(Params const& params, uint32_t tile_idx, uint32_t cur_k_tile) { + tile_peer_range(Params const& params, uint32_t tile_idx, WorkTileInfo const& work_tile_info) { + uint32_t cur_k_tile = static_cast(work_tile_info.K_idx); uint32_t tile_idx_in_cluster_path = params.div_cluster_size(tile_idx); uint32_t start_k_tile = params.divmod_tiles_per_output_tile_.divisor * tile_idx_in_cluster_path; uint32_t end_k_tile = start_k_tile + params.divmod_tiles_per_output_tile_.divisor - 1; uint32_t big_unit_k_tiles = params.big_units_ * (params.divmod_k_tiles_per_sk_unit_.divisor + 1); - auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t k_tiles_per_unit) { - uint32_t unit_k_start = unit_idx * k_tiles_per_unit; - uint32_t unit_k_end = unit_k_start + k_tiles_per_unit; + auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t unit_k_start, uint32_t unit_k_end) { if (k_tile - start_k_tile < Params::min_iters_per_sk_unit_ && unit_k_end - start_k_tile < Params::min_iters_per_sk_unit_) { // k_tile is within the first min_iters_per_sk_unit_ K tiles of this output tile, @@ -943,17 +1052,22 @@ private: if (k_tile < big_unit_k_tiles) { // The tile is within the "big unit range" uint32_t unit_idx = params.divmod_k_tiles_per_sk_big_unit_.divide(k_tile); - return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_big_unit_.divisor)); + uint32_t unit_k_start = unit_idx * params.divmod_k_tiles_per_sk_big_unit_.divisor; + uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_big_unit_.divisor; + return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); } else { // The tile is after the "big unit range." Account for this by finding the "normal unit" // that it belongs to, and then offsetting by the number of big units - uint32_t unit_idx = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles) + params.big_units_; - return static_cast(adjust_unit(k_tile, unit_idx, params.divmod_k_tiles_per_sk_unit_.divisor)); + uint32_t unit_idx_after_big_units = params.divmod_k_tiles_per_sk_unit_.divide(k_tile - big_unit_k_tiles); + uint32_t unit_k_start = unit_idx_after_big_units * params.divmod_k_tiles_per_sk_unit_.divisor + (params.big_units_ * params.divmod_k_tiles_per_sk_big_unit_.divisor); + uint32_t unit_k_end = unit_k_start + params.divmod_k_tiles_per_sk_unit_.divisor; + uint32_t unit_idx = unit_idx_after_big_units + params.big_units_; + return static_cast(adjust_unit(k_tile, unit_idx, unit_k_start, unit_k_end)); } }; - return cute::make_tuple(find_unit(start_k_tile), find_unit(cur_k_tile), find_unit(end_k_tile)); + return cute::make_tuple(find_unit(start_k_tile), find_unit(start_k_tile + cur_k_tile), find_unit(end_k_tile)); } }; diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index 2d9b63ff..ba6b4243 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -37,15 +37,11 @@ #include "cutlass/arch/arch.h" #include "cutlass/detail/dependent_false.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm { -//////////////////////////////////////////////////////////////////////////////// - // // Tags for specifying tile schedulers // @@ -56,10 +52,12 @@ struct StreamKScheduler { }; struct GroupScheduler { }; // Only used for Grouped GEMMs +} // namespace cutlass::gemm //////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass::gemm - +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel::detail { diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 0972731c..da8794bb 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -50,6 +50,26 @@ namespace detail { //////////////////////////////////////////////////////////////////////////////// + CUTLASS_HOST_DEVICE + static uint32_t + get_max_cta_occupancy( + int max_sm_per_gpc, + GemmCoord cluster_shape, + int sm_count) { + // Provided SM count could possibly be less than the assumed maximum SMs per GPC + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; + int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); + int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; + + // The calculation below allows for larger grid size launch for different GPUs. + int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; + int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); + cta_per_device += max_cta_occupancy_per_residual_gpc; + + cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; + return cta_per_device; + } // // Parameters for SM90 tile schedulers // @@ -247,20 +267,7 @@ struct PersistentTileSchedulerSm90Params { * Hence, maximum SMs per GPC = 18 */ constexpr int max_sm_per_gpc = 18; - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; - - // The calculation below allows for larger grid size launch for different GPUs. - int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); - cta_per_device += max_cta_occupancy_per_residual_gpc; - - if (sm_count < cta_per_device) { - cta_per_device = sm_count; - } + cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( cta_per_device / cluster_shape.m(), @@ -467,7 +474,7 @@ struct PersistentTileSchedulerSm90StreamKParams { static constexpr uint32_t max_sk_groups_ = 8u; // ktile start from even for each cta - uint32_t ktile_start_alignment_count { 1u }; + uint32_t ktile_start_alignment_count_ { 1u }; // Divides dividend by the cluster size CUTLASS_HOST_DEVICE @@ -519,7 +526,7 @@ struct PersistentTileSchedulerSm90StreamKParams { ReductionMode reduction_mode, DecompositionMode decomposition_mode, void* workspace, - const uint32_t epilogue_subtile = 1 + const uint32_t epilogue_subtile = 1u ) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl( problem_shape, tile_shape, cluster_shape); @@ -559,6 +566,15 @@ struct PersistentTileSchedulerSm90StreamKParams { void* workspace, const uint32_t epilogue_subtile = 1 ) { + + #if !defined(__CUDACC_RTC__) + if (hw_info.sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + #endif // !defined(__CUDACC_RTC__) + UnderlyingParams underlying_params; underlying_params.initialize( problem_blocks, @@ -568,115 +584,43 @@ struct PersistentTileSchedulerSm90StreamKParams { raster_order_option ); - auto problem_blocks_l = problem_blocks.z; + // Set basic parameters that not affected by any heuristics in advance. + set_params_base(underlying_params, workspace); - auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); - auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); - uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; - - // Reduction workspace is at the beginning of the workspace. Lock workspace follows. - void* reduction_workspace = workspace; - - if (decomposition_mode == DecompositionMode::SplitK || - (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { - // Short circuit to basic split-K decomposition - - // Don't split by more than the available number of SMs - if (splits > hw_info.sm_count) { - splits = hw_info.sm_count; - } - - // Don't split by more than the K tile iterations - // - // splits is almost certainly nonnegative here (e.g., hw_info.sm_count, - // despite being an int, is a count), so it can safely be converted to unsigned - // in the comparison to avoid a signed-unsigned comparison warning-as-error. - if (static_cast(splits) > k_tiles_per_output_tile) { - splits = k_tiles_per_output_tile; - } - - // If splits == k_tiles_per_output_tiles, there will be one k_tile per cta - // and this violate k_tile start from even requirements. Thus we need to - // reduce the number of splits. - if (ktile_start_alignment_count > 1u && - static_cast(splits) == k_tiles_per_output_tile) { - splits = k_tiles_per_output_tile / ktile_start_alignment_count; - } - - set_params_basic( - underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - splits, - k_tiles_per_output_tile, - reduction_workspace, - reduction_mode - ); - return; - } - - // Calculate the maximum number of blocks from clusters of shape cluster_shape that we - // can fit within sm_count SMs. - dim3 grid = get_grid_shape( + // Call for internal streamk heuristic to setup streamk related params + stream_k_heuristic( + underlying_params, problem_blocks, + k_tiles_per_output_tile, cluster_shape, hw_info, + splits, max_swizzle, - raster_order_option - ); + raster_order_option, + decomposition_mode, + reduction_mode, + epilogue_subtile + ); + } + + // max_sk_groups_ unless this extends beyond the extent of the dimension over + // which the problem is rasterized. For example, if the tiled problem shape + // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, + // and we rasterize along the M dimension, we choose 4 groups, rather than 8. + // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). + uint32_t calculate_groups( + UnderlyingParams underlying_params, + ReductionMode reduction_mode, + uint32_t problem_blocks_m, + uint32_t problem_blocks_n, + GemmCoord cluster_shape, + uint64_t cluster_size, + uint32_t sk_tiles, + uint64_t sk_cluster_tiles, + uint64_t sk_units, + uint32_t k_tiles_per_output_tile, + bool do_separate_reduction) { - uint64_t ctas_per_wave = grid.x * grid.y; - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. - uint32_t sk_tiles = get_num_sk_tiles( - output_tiles, - ctas_per_wave, - cluster_size, - k_tiles_per_output_tile, - decomposition_mode - ); - uint64_t dp_tiles = output_tiles - sk_tiles; - - // Calculate the number of work units covering the data-parallel and stream-K tiles. - // A "work unit" is a single index in the linearized ID space used by the scheduler. - // We distinguish it from a "block," which is typically tied to a hardware unit - // (e.g., the callers into this scheduler will be persistent thread blocks). - // A work unit can encompass multiple output tiles worth of work (as will be the - // case for stream-K blocks). - // Since splitting is not required for data-parallel tiles, only one data-parallel unit - // is needed per data-parallel tile. - uint64_t dp_units = dp_tiles; - - uint64_t ctas_per_sk_wave = ctas_per_wave; - uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); - - if (decomposition_mode == DecompositionMode::DataParallel || - (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || - sk_units == 0) { - // Short circuit to basic data-parallel decomposition - set_params_basic( - underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - /* splits = */ 1, - k_tiles_per_output_tile, - reduction_workspace, - reduction_mode - ); - return; - } - - bool do_separate_reduction = should_perform_separate_reduction( - epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); - - // Determine the number of stream-K groups that will be used. We currently use - // max_sk_groups_ unless this extends beyond the extent of the dimension over - // which the problem is rasterized. For example, if the tiled problem shape - // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, - // and we rasterize along the M dimension, we choose 4 groups, rather than 8. - // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). uint32_t max_groups_problem; if (underlying_params.raster_order_ == RasterOrder::AlongM) { max_groups_problem = problem_blocks_m / cluster_shape.m(); @@ -691,14 +635,16 @@ struct PersistentTileSchedulerSm90StreamKParams { // number of K tiles per stream-K unit remains above min_iters_per_sk_unit_ uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_)); - - // Grouping is disabled when separate reduction is used - if (do_separate_reduction) { + // Grouping is disabled when separate reduction is used because grouping is primarily an attempt + // to improve L2 locality, and L2-locality optimizations are unnecessary when the the kernel + // is a single wave (which is the case for separate reduction). + if ( + do_separate_reduction + ) { groups = 1; } uint32_t fallback_groups = 0; - auto sk_cluster_tiles = sk_tiles / cluster_size; auto sk_cluster_units = sk_units / cluster_size; auto sk_splits_too_small = [&](uint32_t g) { @@ -737,82 +683,281 @@ struct PersistentTileSchedulerSm90StreamKParams { if (groups == 1 && fallback_groups > 0) { groups = fallback_groups; } + return groups; + } - auto sk_units_per_group = sk_units / groups; + // Stream-K kernel use below function to set stream-K feature related parameters to choose + // optimal/customized decomposition mode. + void stream_k_heuristic( + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + const uint32_t epilogue_subtile = 1 + ) { + uint32_t groups = 0; + uint32_t sk_tiles = 0; + uint64_t sk_units = 0; + uint64_t cluster_size = 0; + uint64_t dp_units = 0; + uint64_t k_tiles_per_group = 0; + uint64_t k_tiles_per_sk_unit = 0; + uint64_t sk_big_groups = 0; + uint32_t sk_splits = 1; + // Self calculated optimal heuristic mode + DecompositionMode heuristic_mode = + select_decomposition_mode( + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + sk_splits, + underlying_params, + problem_blocks, + k_tiles_per_output_tile, + cluster_shape, + hw_info, + splits, + max_swizzle, + raster_order_option, + decomposition_mode, + reduction_mode, + epilogue_subtile + ); - // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: - // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) - // Both total_tiles and sm_count are multiples of cluster size due to padding added - // prior to kernel launch. - uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; - uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; + // Given heuristic_mode returned from the heuristic() method, set params fields. + // Here, we decouple the params that have no relation with + // decomposition mode from the params that are decided within heuristic(). + set_params( + heuristic_mode, + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + sk_splits, + underlying_params, + problem_blocks, + k_tiles_per_output_tile, + cluster_shape, + splits, + epilogue_subtile, + reduction_mode); + } - // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which - // are stream-K units within a group that process an extra K chunk. - uint64_t sk_big_groups = sk_cluster_tiles % groups; + // Return the optimal decomposition result by heuristic. + DecompositionMode select_decomposition_mode( + uint32_t &groups, + uint32_t &sk_tiles, + uint64_t &sk_units, + uint64_t &cluster_size, + uint64_t &dp_units, + uint64_t &k_tiles_per_group, + uint64_t &k_tiles_per_sk_unit, + uint64_t &sk_big_groups, + uint32_t &sk_splits, + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int splits, + int max_swizzle, + RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, + ReductionMode reduction_mode, + uint32_t epilogue_subtile + ) { - uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; - - // Number of k tiles computed per stream-K unit - uint64_t k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; - - uint32_t reduction_units = 0; - - // Use separate reduction when we have less than one wave of output tiles (dp_tiles == 0) - // and when each tile will be operated on by at least two stream-K units (sk_units > 2 * sk_tiles) - if (do_separate_reduction) { - // Each reduction unit will reduce the partials of an epilogue subtile for - // a given output tile and compute the epilogue. Thus, there are as many reduction - // units as there are epilogue subtiles. - reduction_units = sk_tiles * epilogue_subtile; + // Get block numbers in m, n and l dimensions + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + // Short circuit to basic split-K decomposition + uint32_t adapted_splits = adjust_split_count( + splits, hw_info.sm_count, k_tiles_per_output_tile + ); + sk_splits = adapted_splits; + return DecompositionMode::SplitK; } - else if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { - // If the number of stream-K units is a multiple of the number of stream-K tiles, then - // the problem can leverage a basic split-K decomposition for the stream-K tiles. - // This case happens when separate reduction is disable. - uint32_t sk_splits = static_cast(sk_units / sk_tiles); + else { + // Calculate the maximum number of blocks from clusters of shape cluster_shape that we + // can fit within sm_count SMs. + // Get block numbers in m, n and l dimensions + auto problem_blocks_l = problem_blocks.z; + auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); + uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + dim3 grid = get_grid_shape( + problem_blocks, + cluster_shape, + hw_info, + max_swizzle, + raster_order_option + ); + uint64_t ctas_per_wave = grid.x * grid.y; + cluster_size = cluster_shape.m() * cluster_shape.n(); + // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. + sk_tiles = get_num_sk_tiles( + output_tiles, + ctas_per_wave, + cluster_size, + k_tiles_per_output_tile, + decomposition_mode + ); + uint64_t dp_tiles = output_tiles - sk_tiles; + // Calculate the number of work units covering the data-parallel and stream-K tiles. + // A "work unit" is a single index in the linearized ID space used by the scheduler. + // We distinguish it from a "block," which is typically tied to a hardware unit + // (e.g., the callers into this scheduler will be persistent thread blocks). + // A work unit can encompass multiple output tiles worth of work (as will be the + // case for stream-K blocks). + // Since splitting is not required for data-parallel tiles, only one data-parallel unit + // is needed per data-parallel tile. + dp_units = dp_tiles; + + uint64_t ctas_per_sk_wave = ctas_per_wave; + sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); + + if (decomposition_mode == DecompositionMode::DataParallel || + (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || + sk_units == 0) { + // Short circuit to basic data-parallel decomposition + return DecompositionMode::DataParallel; + } + else { + bool do_separate_reduction = should_perform_separate_reduction( + epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); + + uint64_t sk_cluster_tiles = sk_tiles / cluster_size; + + groups = calculate_groups(underlying_params, reduction_mode, problem_blocks_m, problem_blocks_n, cluster_shape, + cluster_size, sk_tiles, sk_cluster_tiles, sk_units, k_tiles_per_output_tile, do_separate_reduction); + + auto sk_units_per_group = sk_units / groups; + + // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: + // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) + // Both total_tiles and sm_count are multiples of cluster size due to padding added + // prior to kernel launch. + uint64_t sk_cluster_tiles_per_group = sk_cluster_tiles / groups; + uint64_t sk_tiles_per_group = sk_cluster_tiles_per_group * cluster_size; + + // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which + // are stream-K units within a group that process an extra K chunk. + sk_big_groups = sk_cluster_tiles % groups; + + k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; + + // Number of k tiles computed per stream-K unit + k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; + + DecompositionMode heuristic_mode; + if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { + // If the number of stream-K units is a multiple of the number of stream-K tiles, then + // the problem can leverage a basic split-K decomposition for the stream-K tiles. + // This case happens when separate reduction is disable. + sk_splits = static_cast(sk_units / sk_tiles); + heuristic_mode = DecompositionMode::SplitK; + } + else { + // Rest scenario is streamk + heuristic_mode = DecompositionMode::StreamK; + } + // Refresh heuristic_mode using analytical model before choosing streamk/separate_reduction decomposition, + // ideally it's to get the final decomposition more accuracy. Comment it as it is place holder at this moment. + #if 0 + uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); + analytical_model(heuristic_mode, k_tiles_per_output_tile, k_tiles_per_sk_unit, + sk_splits, epilogue_subtile, total_waves); + #endif + return heuristic_mode; + } + } + } + + // Given decomposition mode output from heuristic, set all feilds of params. + void set_params( + DecompositionMode heuristic_mode, + uint32_t groups, + uint32_t sk_tiles, + uint64_t sk_units, + uint64_t cluster_size, + uint64_t dp_units, + uint64_t k_tiles_per_group, + uint64_t k_tiles_per_sk_unit, + uint64_t sk_big_groups, + uint32_t sk_splits, + UnderlyingParams underlying_params, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord cluster_shape, + uint32_t splits, + uint32_t epilogue_subtile, + ReductionMode reduction_mode) { + // The highest priority when customers set as splitk mode, may set + // with a adpated splits value rather than the original splits + // even it does not make sense + if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) { set_params_basic( underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - sk_splits, + problem_blocks, + cluster_shape, + sk_splits, // split-k set by customers k_tiles_per_output_tile, - reduction_workspace, reduction_mode ); - return; } - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; - divmod_batch_ = underlying_params.divmod_batch_; - divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; - divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); - divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); - - // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. - // This setting ensures that the use of this divmod for stream-K decompositions - // is essentially a no-op. - divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); - divmod_splits_ = FastDivmod(1); - log_swizzle_size_ = underlying_params.log_swizzle_size_; - units_per_problem_ = static_cast(dp_units + sk_units); - raster_order_ = underlying_params.raster_order_; - - // Assign big_units_ assuming that group count == 1. This is unused by stream-K - // when group count > 1. - big_units_ = static_cast(k_tiles_per_group % k_tiles_per_sk_unit); - - big_groups_ = static_cast(sk_big_groups); - reduction_workspace_ = reduction_workspace; - sk_tiles_ = sk_tiles; - sk_units_ = static_cast(sk_units); - divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); - divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); - reduction_mode_ = reduction_mode; - divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); - separate_reduction_units_ = reduction_units; + else if (heuristic_mode == DecompositionMode::DataParallel) { + set_params_basic( + underlying_params, + problem_blocks, + cluster_shape, + 1, // fast path to fall back to the mode without any split scheme + k_tiles_per_output_tile, + reduction_mode + ); + } + else if (heuristic_mode == DecompositionMode::SplitK) { + set_params_basic( + underlying_params, + problem_blocks, + cluster_shape, + sk_splits, // splits calculated by heuristic + k_tiles_per_output_tile, + reduction_mode + ); + } + else { + // streamk + set_params_stream_k( + underlying_params, + k_tiles_per_output_tile, + groups, + sk_tiles, + sk_units, + cluster_size, + dp_units, + k_tiles_per_group, + k_tiles_per_sk_unit, + sk_big_groups, + reduction_mode, + 1, /*epilogue_subtile*/ + 0 /*reduction_units*/ + ); + } } // Given the inputs, computes the physical grid we should launch. @@ -897,7 +1042,6 @@ struct PersistentTileSchedulerSm90StreamKParams { // or if there is no work to be split. return 0; } - // // The final wave is not full. Perform some stream-K work. // @@ -971,11 +1115,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t accumulator_bits, uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1) { + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle); problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); @@ -989,12 +1135,6 @@ struct PersistentTileSchedulerSm90StreamKParams { barrier_workspace_size = 0; reduction_workspace_size = 0; } - else if (splits > 1 && - (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic)) { - // Basic split-K variant requires workspace for all output tiles - barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); - } else { KernelHardwareInfo new_hw_info; new_hw_info.device_id = hw_info.device_id; @@ -1025,20 +1165,42 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_sk_wave, sk_tiles, k_tiles_per_output_tile); uint64_t dp_tiles = output_tiles - sk_tiles; - uint64_t reduction_tiles = sk_tiles; - if (should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave)) { - // In separate reduction, each peer writes to its own location in scratch space. - // Thus, for separate reduction, we need as many reduction tiles per output tile - // as there are the maximum number of peers that can collaborate on an output tile. - reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + splits = adjust_split_count( + splits, new_hw_info.sm_count, k_tiles_per_output_tile + ); } - // Though separate reduction requires a larger reduction workspace, only one barrier - // is needed per output tile. Each peer will increment the barrier by one once the peer has - // written its accumulator to scratch space. The separate reduction unit will only begin - // performing the reduction when the barrier has reached the number of peers for the output tile. - barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + bool split_k_required = splits > 1 && (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic); + bool split_k_selected = decomposition_mode == DecompositionMode::Heuristic && + sk_units > sk_tiles && + sk_tiles != 0 && + sk_units % sk_tiles == 0; + + if (split_k_required || split_k_selected) { + // Basic split-K variant requires workspace for all output tiles + barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + } + else { + uint64_t reduction_tiles = sk_tiles; + if ( + should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave) + ) { + // In separate reduction, each peer writes to its own location in scratch space. + // Thus, for separate reduction, we need as many reduction tiles per output tile + // as there are the maximum number of peers that can collaborate on an output tile. + reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); + } + + // Though separate reduction requires a larger reduction workspace, only one barrier + // is needed per output tile. Each peer will increment the barrier by one once the peer has + // written its accumulator to scratch space. The separate reduction unit will only begin + // performing the reduction when the barrier has reached the number of peers for the output tile. + barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits, num_accumulator_mtxs); + } } } #endif // !defined(__CUDACC_RTC__) @@ -1063,11 +1225,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile, - uint32_t num_accumulator_mtxs) { + uint32_t num_accumulator_mtxs, + uint32_t ktile_start_alignment_count = 1) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -1082,11 +1246,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); } @@ -1104,11 +1270,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile = 1, - uint32_t num_accumulator_mtxs = 1) { + uint32_t num_accumulator_mtxs = 1, + uint32_t ktile_start_alignment_count = 1) { size_t barrier_workspace_size = 0; size_t reduction_workspace_size = 0; @@ -1126,11 +1294,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); #endif @@ -1151,11 +1321,13 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile, - CudaHostAdapter* cuda_adapter = nullptr) { + CudaHostAdapter* cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -1172,12 +1344,14 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, 1, - cuda_adapter + cuda_adapter, + ktile_start_alignment_count ); } @@ -1197,12 +1371,14 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, DecompositionMode decomposition_mode, + ReductionMode reduction_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, uint32_t element_accumulator_bits, uint32_t epilogue_subtile = 1, uint32_t num_accumulator_mtxs = 1, - CudaHostAdapter* cuda_adapter = nullptr) { + CudaHostAdapter* cuda_adapter = nullptr, + uint32_t ktile_start_alignment_count = 1) { #if !defined(__CUDACC_RTC__) uint64_t barrier_workspace_size = 0; @@ -1220,11 +1396,13 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, decomposition_mode, + reduction_mode, mma_warp_groups, barrier_bits, element_accumulator_bits, epilogue_subtile, - num_accumulator_mtxs + num_accumulator_mtxs, + ktile_start_alignment_count ); if (barrier_workspace_size > 0) { @@ -1242,31 +1420,41 @@ struct PersistentTileSchedulerSm90StreamKParams { return Status::kSuccess; } + // Set params for basic parameters, which will not affected by different decompositions. + void + set_params_base(UnderlyingParams const& underlying_params, void* reduction_workspace) { + divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; + divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; + divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; + log_swizzle_size_ = underlying_params.log_swizzle_size_; + raster_order_ = underlying_params.raster_order_; + reduction_workspace_ = reduction_workspace; + } + void set_params_basic( UnderlyingParams const& underlying_params, - uint32_t blocks_m, - uint32_t blocks_n, - uint32_t blocks_l, + dim3 problem_blocks, + GemmCoord cluster_shape, uint32_t splits, uint32_t k_tiles_per_output_tile, - void* reduction_workspace, ReductionMode reduction_mode) { - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; + auto blocks_l = problem_blocks.z; + auto blocks_m = round_up(problem_blocks.x, + (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); + auto blocks_n = round_up(problem_blocks.y, + (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); + divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); divmod_sk_groups_ = FastDivmodU64(1u); - auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor; + auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * + underlying_params.divmod_cluster_shape_minor_.divisor; divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); divmod_splits_ = FastDivmod(splits); - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; - log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = blocks_m * blocks_n * blocks_l; - raster_order_ = underlying_params.raster_order_; big_units_ = k_tiles_per_output_tile % splits; - reduction_workspace_ = reduction_workspace; reduction_mode_ = reduction_mode; divmod_k_tiles_per_sk_unit_ = FastDivmod(k_tiles_per_output_tile / splits); divmod_k_tiles_per_sk_big_unit_ = FastDivmod(k_tiles_per_output_tile / splits + 1); @@ -1278,6 +1466,55 @@ struct PersistentTileSchedulerSm90StreamKParams { separate_reduction_units_ = 0; } + // Set params for streamk(streamk, separate-reduction included) decomposition. + void + set_params_stream_k( + UnderlyingParams const& underlying_params, + uint32_t k_tiles_per_output_tile, + uint32_t groups, + uint32_t sk_tiles, + uint64_t sk_units, + uint64_t cluster_size, + uint64_t dp_units, + uint64_t k_tiles_per_group, + uint64_t k_tiles_per_sk_unit, + uint64_t sk_big_groups, + ReductionMode reduction_mode, + uint32_t epilogue_subtile, + uint32_t reduction_units) { + // stream-k and separate-reduction decompostions + divmod_batch_ = underlying_params.divmod_batch_; + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); + divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); + + // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. + // This setting ensures that the use of this divmod for stream-K decompositions + // is essentially a no-op. + divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); + divmod_splits_ = FastDivmod(1); + units_per_problem_ = static_cast(dp_units + sk_units); + + // Assign big_units_ assuming that group count == 1. This is unused by stream-K + // when group count > 1. + auto big_units_in_ctas = k_tiles_per_group % sk_units; + + // Store big_units in terms of clusters. big_units_in_ctas is guaranteed to be divisible + // by cluster_size because both k_tiles_per_group and k_tiles_per_sk_unit must be a multiple + // of cluster_size. + auto big_units_in_clusters = big_units_in_ctas / cluster_size; + big_units_ = static_cast(big_units_in_clusters); + + big_groups_ = static_cast(sk_big_groups); + sk_tiles_ = sk_tiles; + sk_units_ = static_cast(sk_units); + divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); + divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); + reduction_mode_ = reduction_mode; + divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); + separate_reduction_units_ = reduction_units; + } + private: // Round up number of bytes to the nearest multiple of L2 cache line alignment CUTLASS_HOST_DEVICE @@ -1286,8 +1523,31 @@ struct PersistentTileSchedulerSm90StreamKParams { constexpr size_t L2CacheLineSizeBytes = 128u; return (bytes + L2CacheLineSizeBytes - 1) / L2CacheLineSizeBytes * L2CacheLineSizeBytes; } + + CUTLASS_HOST_DEVICE + static int adjust_split_count( + int splits, + int sm_count, + uint32_t k_tiles_per_output_tile + ) { + // Don't split by more than the available number of SMs + if (splits > sm_count) { + splits = sm_count; + } + + // Don't split by more than the K tile iterations + if (static_cast(splits) > k_tiles_per_output_tile) { + splits = k_tiles_per_output_tile; + } + + // If k_tiles_per_output_tiles / splits == 1, there will be one k_tile per cta + // and this violate k_tile start from even requirements. Thus we need to + // reduce the number of splits. + return splits; + } }; + //////////////////////////////////////////////////////////////////////////////// // Parameters for SM90 persistent group scheduler (only used for Grouped Gemms) @@ -1453,18 +1713,7 @@ struct PersistentTileSchedulerSm90GroupParams { // GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU // Hence, maximum SMs per GPC = 18 constexpr int max_sm_per_gpc = 18; - // Provided SM count could possibly be less than the assumed maximum SMs per GPC - auto cluster_size = cluster_shape.m() * cluster_shape.n(); - int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); - int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; - - // The calculation below allows for larger grid size launch for different GPUs. - int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); - cta_per_device += max_cta_occupancy_per_residual_gpc; - - cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; + int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); if (raster_order == RasterOrder::AlongN) { launch_grid.y = possibly_truncate( diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index c778832b..4c70bcf3 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -147,7 +147,7 @@ struct MmaGeneric { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) - if (kMultipleOf2 && kAllFp32) { + if constexpr (kMultipleOf2 && kAllFp32) { //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN; n+=2) { @@ -396,34 +396,36 @@ struct MmaGeneric< CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < Shape::kN; ++n) { - + { CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < Shape::kM; ++m) { + for (int n = 0; n < Shape::kN; ++n) { - int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { - MatrixCoord mn(m_serpentine, n); - MatrixCoord mk(m_serpentine, k); - MatrixCoord kn(k, n); + int m_serpentine = (n % 2) ? (Shape::kM - 1 - m) : m; - Array d; - Array a; - Array b; + MatrixCoord mn(m_serpentine, n); + MatrixCoord mk(m_serpentine, k); + MatrixCoord kn(k, n); - d[0] = d_ref.at(mn); - a[0] = a_ref.at(mk); - b[0] = b_ref.at(kn); + Array d; + Array a; + Array b; - if ((m == 0 && n) || m == Shape::kM - 1) { - mma_corner(d, a, b, d); + d[0] = d_ref.at(mn); + a[0] = a_ref.at(mk); + b[0] = b_ref.at(kn); + + if ((m == 0 && n) || m == Shape::kM - 1) { + mma_corner(d, a, b, d); + } + else { + mma_column(d, a, b, d); + } + + d_ref.at(mn) = d[0]; } - else { - mma_column(d, a, b, d); - } - - d_ref.at(mn) = d[0]; } } } diff --git a/include/cutlass/gemm/threadblock/ell_mma_multistage.h b/include/cutlass/gemm/threadblock/ell_mma_multistage.h index 27f410cc..17cc9dae 100644 --- a/include/cutlass/gemm/threadblock/ell_mma_multistage.h +++ b/include/cutlass/gemm/threadblock/ell_mma_multistage.h @@ -243,12 +243,12 @@ public: if (is_offset_constant){ auto ell_offset = ell_iter.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; } else { int k_offset = iterator_A.get_k(); auto ell_offset = ell_iter.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; } } @@ -287,12 +287,12 @@ public: if (is_offset_constant){ auto ell_offset = ell_iter.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; } else { int k_offset = iterator_B.get_k(); auto ell_offset = ell_iter.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; } } @@ -359,12 +359,12 @@ public: if (is_offset_constant){ auto ell_offset = ell_iterator.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorA::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorA::Element) / kSrcBytes; } else { int k_offset = iterator_A.get_k(); auto ell_offset = ell_iterator.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += (ell_offset * sizeof(IteratorA::Element)) / kSrcBytes; + gmem_ptr += (ell_offset * sizeof(typename IteratorA::Element)) / kSrcBytes; } } @@ -401,12 +401,12 @@ public: if (is_offset_constant){ auto ell_offset = ell_iterator.get_offset_fast(); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ell_offset * sizeof(IteratorB::Element) / kSrcBytes; + gmem_ptr += ell_offset * sizeof(typename IteratorB::Element) / kSrcBytes; } else { int k_offset = iterator_B.get_k(); auto ell_offset = ell_iterator.get_offset(k_offset); is_valid = is_valid && (ell_offset >= 0); - gmem_ptr += ( ell_offset * sizeof(IteratorB::Element)) / kSrcBytes; + gmem_ptr += ( ell_offset * sizeof(typename IteratorB::Element)) / kSrcBytes; } } diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index b84d322d..27a50fd2 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -93,7 +93,7 @@ struct integer_subbyte { [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; assert(value >= lower_bound); - assert(value < upper_bound); + assert(value <= upper_bound); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -112,7 +112,7 @@ struct integer_subbyte { [[maybe_unused]] constexpr int lower_bound = -(1 << (Bits - 1)); [[maybe_unused]] constexpr int upper_bound = (1 << (Bits - 1)) - 1; assert(value >= lower_bound); - assert(value < upper_bound); + assert(value <= upper_bound); } else { [[maybe_unused]] constexpr unsigned upper_bound = 1u << Bits; @@ -120,6 +120,10 @@ struct integer_subbyte { } } + CUTLASS_HOST_DEVICE explicit + integer_subbyte(uint8_t value) + : integer_subbyte(static_cast(value)) {} + // Convert to the "external" integer type (int or unsigned) CUTLASS_HOST_DEVICE operator xint_t() const { diff --git a/include/cutlass/kernel_launch.h b/include/cutlass/kernel_launch.h index ca3380a2..4cd087a3 100644 --- a/include/cutlass/kernel_launch.h +++ b/include/cutlass/kernel_launch.h @@ -37,6 +37,7 @@ #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" +#include "cutlass/device_kernel.h" // cutlass::device_kernel namespace cutlass { diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 912eb2c8..13e5ef22 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -38,11 +38,8 @@ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. */ #pragma once -#if defined(__CUDACC_RTC__) + #include -#else -#include "assert.h" -#endif #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 8374fe31..d296f1d0 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -39,11 +39,9 @@ defined in cutlass/tensor_ref.h. */ #pragma once -#if defined(__CUDACC_RTC__) + #include -#else -#include "assert.h" -#endif + #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm70.h b/include/cutlass/layout/tensor_op_multiplicand_sm70.h index 4691b982..b260942a 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm70.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm70.h @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" #include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_coord.h" // cutlass::MatrixCoord ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 9c4fb395..b62a90cc 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -95,7 +95,6 @@ struct NumericConverter { // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) template <> struct NumericConverter { @@ -103,50 +102,17 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if __CUDA_ARCH__ return __float2int_rn(s); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - CUTLASS_DEVICE - static result_type convert(source_type const & s) { - - return __float2int_rz(s); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#elif !defined(__CUDACC_RTC__) - -template <> -struct NumericConverter { - - using result_type = int32_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - static result_type convert(source_type const & s) { + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TONEAREST); - return (result_type)std::nearbyint(s); + return static_cast(std::nearbyint(s)); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -159,16 +125,21 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { + #if __CUDA_ARCH__ + return __float2int_rz(s); + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TOWARDZERO); return (result_type)std::nearbyint(s); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -#endif ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -176,7 +147,6 @@ struct NumericConverter { // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) template <> struct NumericConverter { @@ -184,13 +154,21 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - + #if defined(__CUDA_ARCH__) int32_t intermediate; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) + std::fesetround(FE_TONEAREST); + int32_t intermediate = (int32_t)std::nearbyint(s); + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); + #endif } CUTLASS_HOST_DEVICE @@ -206,16 +184,24 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - + #if defined(__CUDA_ARCH__) int32_t intermediate; asm volatile("cvt.rzi.sat.s8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) + std::fesetround(FE_TOWARDZERO); + int32_t intermediate = (int32_t)std::nearbyint(s); + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); + #endif } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } @@ -228,13 +214,21 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - + #if defined(__CUDA_ARCH__) int32_t intermediate; asm volatile("cvt.rni.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); + #elif !defined(__CUDACC_RTC__) + std::fesetround(FE_TONEAREST); + int32_t intermediate = (int32_t)std::nearbyint(s); + // Low-end saturation + intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); + // High-end saturation + intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); + return static_cast(intermediate); + #endif } CUTLASS_HOST_DEVICE @@ -250,125 +244,29 @@ struct NumericConverter { using source_type = float; static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { - + #if __CUDA_ARCH__ int32_t intermediate; asm volatile("cvt.rzi.sat.u8.f32 %0, %1;" : "=r"(intermediate) : "f"(s)); - return static_cast(intermediate); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#elif !defined(__CUDACC_RTC__) - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - static result_type convert(source_type const & s) { - std::fesetround(FE_TONEAREST); - int32_t intermediate = (int32_t)std::nearbyint(s); - - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - - return static_cast(intermediate); - } - - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = int8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - static result_type convert(source_type const & s) { + #elif !defined(__CUDACC_RTC__) std::fesetround(FE_TOWARDZERO); int32_t intermediate = (int32_t)std::nearbyint(s); - - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - - return static_cast(intermediate); - } - - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest; - - static result_type convert(source_type const & s) { - std::fesetround(FE_TONEAREST); - int32_t intermediate = (int32_t)std::nearbyint(s); - // Low-end saturation intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - // High-end saturation intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - return static_cast(intermediate); + #endif } + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -template <> -struct NumericConverter { - - using result_type = uint8_t; - using source_type = float; - static FloatRoundStyle const round_style = FloatRoundStyle::round_toward_zero; - - static result_type convert(source_type const & s) { - std::fesetround(FE_TOWARDZERO); - int32_t intermediate = (int32_t)std::nearbyint(s); - - // Low-end saturation - intermediate = std::max(intermediate, (int32_t)std::numeric_limits::lowest()); - - // High-end saturation - intermediate = std::min(intermediate, (int32_t)std::numeric_limits::max()); - - return static_cast(intermediate); - } - - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -#endif - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for float => integer_subbyte @@ -3281,88 +3179,88 @@ namespace detail { ///////////////////////////////////////////////////////////////////////////////////////////////// -#if defined(__CUDA_ARCH__) -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const & source) { - - unsigned const& storage = reinterpret_cast(source); - unsigned out[2]; - - asm volatile( - "{\n" - " .reg .u32 tmp0, tmp1, tmp2;\n" - " shl.b32 tmp0, %2, 4;\n" // tmp0 = x1x2x3x4x5x6x7__ - " and.b32 tmp0, tmp0, 0xf0f0f0f0;\n" // tmp0 = x1__x3__x5__x7__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s1s3s5s7 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s1__s3__s5__s7__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x1__x3__x5__x7 - " or.b32 tmp2, tmp0, tmp1;\n" // tmp2 = y1y3y5y7 - " and.b32 tmp0, %2, 0xf0f0f0f0;\n" // tmp0 = x0__x2__x4__x6__ - " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s0s2s4s6 - " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s0__s2__s4__s6__ - " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x0__x2__x4__x6 - " or.b32 tmp0, tmp0, tmp1;\n" // tmp0 = y0y2y4y6 - " prmt.b32 %0, tmp2, tmp0, 0x5140;\n" // %0 = y0y1y2y3 - " prmt.b32 %1, tmp2, tmp0, 0x7362;\n" // %1 = y4y5y6y7 - "}\n" - : "=r"(out[0]), "=r"(out[1]) - : "r"(storage)); - - return reinterpret_cast(out); - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - /// Partial specialization for Array <= Array template < int N, FloatRoundStyle Round > struct NumericArrayConverter { - static_assert(!(N % 8), "N must be multiple of 8."); + + static_assert(N % 8 == 0, "N must be a multiple of 8"); using result_type = Array; using source_type = Array; static FloatRoundStyle const round_style = Round; - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { + + #if defined(__CUDA_ARCH__) - NumericArrayConverter convert_vector_; + if constexpr ( N == 8 ) { + + unsigned const& storage = reinterpret_cast(source); + unsigned out[2]; - result_type result; + asm volatile( + "{\n" + " .reg .u32 tmp0, tmp1, tmp2;\n" + " shl.b32 tmp0, %2, 4;\n" // tmp0 = x1x2x3x4x5x6x7__ + " and.b32 tmp0, tmp0, 0xf0f0f0f0;\n" // tmp0 = x1__x3__x5__x7__ + " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s1s3s5s7 + " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s1__s3__s5__s7__ + " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x1__x3__x5__x7 + " or.b32 tmp2, tmp0, tmp1;\n" // tmp2 = y1y3y5y7 + " and.b32 tmp0, %2, 0xf0f0f0f0;\n" // tmp0 = x0__x2__x4__x6__ + " prmt.b32 tmp1, tmp0, tmp0, 0xba98;\n" // tmp1 = s0s2s4s6 + " and.b32 tmp1, tmp1, 0xf0f0f0f0;\n" // tmp1 = s0__s2__s4__s6__ + " shr.u32 tmp0, tmp0, 4;\n" // tmp0 = __x0__x2__x4__x6 + " or.b32 tmp0, tmp0, tmp1;\n" // tmp0 = y0y2y4y6 + " prmt.b32 %0, tmp2, tmp0, 0x5140;\n" // %0 = y0y1y2y3 + " prmt.b32 %1, tmp2, tmp0, 0x7362;\n" // %1 = y4y5y6y7 + "}\n" + : "=r"(out[0]), "=r"(out[1]) + : "r"(storage)); - Array *result_ptr = reinterpret_cast *>(&result); - Array const *source_ptr = reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 8; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); + return reinterpret_cast(out); + + } else { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; } - + + #else + + result_type result; + NumericConverter convert_; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = convert_(source[i]); + } + return result; + + #endif // __CUDA_ARCH__ } - CUTLASS_DEVICE + CUTLASS_HOST_DEVICE result_type operator()(source_type const &s) const { return convert(s); } }; -#endif // defined(__CUDA_ARCH__) /// Partial specialization for Array <= Array template diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 4ff83bab..98fd77c3 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -68,6 +68,15 @@ bits_to_bytes(T bits) { return (R(bits) + R(7)) / R(8); } +/// Returns the number of bits required to hold a specified number of bytes +template +CUTLASS_HOST_DEVICE +constexpr +R +bytes_to_bits(T bytes) { + return R(bytes) * R(8); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// template diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index 5519fbe7..ca37896b 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -34,8 +34,6 @@ */ #pragma once -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" #include "cutlass/numeric_size.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 96bb8db7..b1d04f51 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -51,6 +51,65 @@ namespace cutlass { using namespace cute; +namespace detail { + +// Helper function for DEBUG checks +template +CUTLASS_DEVICE +bool pipeline_is_producer(ThreadCategory role) { + return (role == ThreadCategory::Producer || role == ThreadCategory::ProducerConsumer); +} + +template +CUTLASS_DEVICE +void pipeline_check_is_producer(ThreadCategory role) { + #ifndef NDEBUG + if (!pipeline_is_producer(role)) { + asm volatile ("brkpt;\n" ::); + } + #endif +} + +template +CUTLASS_DEVICE +bool pipeline_is_consumer(ThreadCategory role) { + return (role == ThreadCategory::Consumer || role == ThreadCategory::ProducerConsumer); +} + +template +CUTLASS_DEVICE +void pipeline_check_is_consumer(ThreadCategory role) { + #ifndef NDEBUG + if (!pipeline_is_consumer(role)) { + asm volatile ("brkpt;\n" ::); + } + #endif +} + +CUTLASS_DEVICE +cute::tuple spread_arrivals_to_warp(int thread_idx_in_warp) { + constexpr uint32_t MaxClusterSize = 16; + bool is_signaling_thread = (thread_idx_in_warp % (32 / MaxClusterSize)) == 0; + auto layout = Layout,Stride<_4, _1>>{}; + uint32_t thread_row = thread_idx_in_warp / 8; + uint32_t thread_col = (thread_idx_in_warp % 8) / 2; + uint32_t dst_blockid = layout(thread_row, thread_col); + return cute::make_tuple(is_signaling_thread, dst_blockid); +} + +CUTLASS_DEVICE +cute::tuple spread_arrivals_to_warpgroup(int thread_idx_in_warpgroup, int warp_idx) { + constexpr uint32_t MaxClusterSize = 16; + bool is_signaling_thread = (thread_idx_in_warpgroup % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; + auto layout = cute::composition(Swizzle<2,0,-2>{}, + Layout,Stride<_4,_1>>{}); + uint32_t thread_row = warp_idx % 4; + uint32_t thread_col = (thread_idx_in_warpgroup / 8) % 4; + uint32_t dst_blockid = layout(thread_row, thread_col); + return cute::make_tuple(is_signaling_thread, dst_blockid); +} +} // namespace detail + enum class BarrierStatus : uint32_t { WaitAgain = 0u, WaitDone = 1u, @@ -210,7 +269,7 @@ PipelineState make_producer_start_state() { // Currently, it is optional to elect a leader for the Consumers template class PipelineTmaAsync { -public : +public: using FullBarrier = cutlass::arch::ClusterTransactionBarrier; using EmptyBarrier = cutlass::arch::ClusterBarrier; using ProducerBarrierType = FullBarrier::ValueType; @@ -237,68 +296,92 @@ public : uint32_t num_consumers = 0; }; - // Constructor - template + template + static CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 + if (cute::size(cluster_shape) > 1) { + multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * + num_consumer_warpgroups_per_cluster; + } + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) : params_(params) , full_barrier_ptr_(&storage.full_barrier_[0]) , empty_barrier_ptr_(&storage.empty_barrier_[0]) { int warp_idx = canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; int lane_predicate = cute::elect_one_sync(); - if (warp_idx == 0 && lane_predicate == 1) { - // Barrier FULL init - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(1); + static_assert(cute::is_same_v || cute::is_same_v); + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + if constexpr (cute::is_same_v) { + // Logic to optimally schedule Empty Arrives + // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) + dim3 block_id = cute::block_id_in_cluster(); + auto cluster_size = cute::size(cluster_shape); + + if (cluster_size == 1) { + is_signaling_thread_ = true; + dst_blockid_ = 0; } - uint32_t const num_consumer_warpgroups_per_cluster = params_.num_consumers / NumThreadsPerWarpGroup; - uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * - num_consumer_warpgroups_per_cluster; - // Barrier EMPTY init - for (int i = 0; i < Stages; ++i) { - empty_barrier_ptr_[i].init(multicast_consumer_arrival_count); + else { + // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) + if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warpgroup(thread_idx % NumThreadsPerWarpGroup, warp_idx); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else if (params_.num_consumers == 32) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warp(thread_idx % 32); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else { + is_signaling_thread_ = 0; + #ifndef NDEBUG + asm volatile ("brkpt;\n" ::); + #endif + } + + // STEP 2: Find if this dst block-id needs an arrival for this problem + is_signaling_thread_ &= dst_blockid_ < cluster_size; + is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); } } - cutlass::arch::fence_barrier_init(); - - // Logic to optimally schedule Empty Arrives - // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) - dim3 block_id = cute::block_id_in_cluster(); - auto cluster_size = cute::size(cluster_shape); - static constexpr int MaxClusterSize = 16; - - // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) - if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { - int thread_idx = threadIdx.x % NumThreadsPerWarpGroup; - is_signalling_thread_ = (thread_idx % (NumThreadsPerWarpGroup / MaxClusterSize)) == 0; - auto layout = cute::composition(Swizzle<2,0,-2>{}, - Layout,Stride<_4,_1>>{}); - uint32_t thread_row = warp_idx % 4; - uint32_t thread_col = (thread_idx / 8) % 4; - dst_blockid_ = layout(thread_row, thread_col); - } - else if (params_.num_consumers == 32) { - int thread_idx = threadIdx.x % 32; - is_signalling_thread_ = (thread_idx % (32 / MaxClusterSize)) == 0; - auto layout = Layout,Stride<_4, _1>>{}; - uint32_t thread_row = thread_idx / 8; - uint32_t thread_col = (thread_idx % 8) / 2; - dst_blockid_ = layout(thread_row, thread_col); - } - else { - is_signalling_thread_ = 0; - #ifndef NDEBUG - asm volatile ("brkpt;\n" ::); - #endif - } - - // STEP 2: Find if this dst block-id needs an arrival for this problem - is_signalling_thread_ &= dst_blockid_ < cluster_size; - is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); } + // Constructor + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) + : PipelineTmaAsync(storage, params, cluster_shape, cute::true_type{}, cute::true_type{}) { } + + template + CUTLASS_DEVICE + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}) + : PipelineTmaAsync(storage, params, cluster_shape, InitBarriers{}, cute::true_type{}) { } + template CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { @@ -347,6 +430,7 @@ public : // This should be called once before kernel exits. CUTLASS_DEVICE void producer_tail(PipelineState state) { + detail::pipeline_check_is_producer(params_.role); for (int count = 0; count < Stages; ++count) { empty_barrier_ptr_[state.index()].wait(state.phase()); ++state; @@ -386,15 +470,16 @@ public : consumer_release(state.index()); } -private : +private: uint32_t dst_blockid_ = 0; - uint32_t is_signalling_thread_ = 0; + uint32_t is_signaling_thread_ = 0; FullBarrier *full_barrier_ptr_ = nullptr; EmptyBarrier *empty_barrier_ptr_ = nullptr; Params params_; CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -404,6 +489,7 @@ private : CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token != BarrierStatus::WaitDone) { empty_barrier_ptr_[stage].wait(phase); } @@ -454,6 +540,7 @@ private : CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -463,6 +550,7 @@ private : CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -473,12 +561,14 @@ private : // Wait for producer to commit transactions (done by TMA) CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { + detail::pipeline_check_is_consumer(params_.role); full_barrier_ptr_[stage].wait(phase); } // Wait for producer to commit transactions (done by TMA) CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -488,7 +578,8 @@ private : // Ensures all blocks in the Same Row and Column get notifed. CUTLASS_DEVICE void consumer_release(uint32_t stage, uint32_t skip = false) { - empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signalling_thread_ & (!skip)); + detail::pipeline_check_is_consumer(params_.role); + empty_barrier_ptr_[stage].arrive(dst_blockid_, is_signaling_thread_ & (!skip)); #ifndef NDEBUG if (params_.role == ThreadCategory::Producer || params_.role == ThreadCategory::NonParticipant) { asm volatile ("brkpt;\n" ::); @@ -625,7 +716,7 @@ private: /////////////////////////////////////////////////////////////////////////////////////////////////// template class PipelineTransactionAsync { -public : +public: using FullBarrier = cutlass::arch::ClusterTransactionBarrier; using EmptyBarrier = cutlass::arch::ClusterBarrier; using ProducerBarrierType = FullBarrier::ValueType; @@ -653,26 +744,45 @@ public : uint32_t dst_blockid = cute::block_rank_in_cluster(); }; - // Constructor + static CUTLASS_DEVICE - PipelineTransactionAsync(SharedStorage& storage, Params const& params) + void + init_barriers(SharedStorage& storage, Params const& params) { + FullBarrier *full_barrier_ptr = storage.full_barrier_.data(); + EmptyBarrier *empty_barrier_ptr = storage.empty_barrier_.data(); + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + full_barrier_ptr, empty_barrier_ptr, params.producer_arv_count, params.consumer_arv_count); + } + } + + // Constructor + template + CUTLASS_DEVICE + PipelineTransactionAsync(SharedStorage& storage, Params const& params, InitBarriers = cute::true_type{}) : params_(params) , full_barrier_ptr_(storage.full_barrier_.data()) , empty_barrier_ptr_(storage.empty_barrier_.data()) { + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); - // Barrier FULL, EMPTY init - // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate) { - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(params.producer_arv_count); - empty_barrier_ptr_[i].init(params.consumer_arv_count); - } + static_assert(cute::is_same_v || cute::is_same_v); + + if constexpr (cute::is_same_v) { + init_barriers(storage, params); } - cutlass::arch::fence_barrier_init(); + } + // Constructor + CUTLASS_DEVICE + PipelineTransactionAsync(SharedStorage& storage, Params const& params) : + PipelineTransactionAsync(storage, params, cute::true_type{}) { } + //////////////////// // Producer APIs //////////////////// @@ -758,6 +868,7 @@ private: CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -767,6 +878,7 @@ private: CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { empty_barrier_ptr_[stage].wait(phase); } @@ -775,11 +887,13 @@ private: // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread CUTLASS_DEVICE void producer_expect_transaction(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].expect_transaction(params_.transaction_bytes); } CUTLASS_DEVICE void producer_commit(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].arrive(params_.dst_blockid); } @@ -790,6 +904,7 @@ private: CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -799,6 +914,7 @@ private: CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -808,6 +924,7 @@ private: CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -815,6 +932,7 @@ private: CUTLASS_DEVICE void consumer_release(uint32_t stage, uint32_t skip = false) { + detail::pipeline_check_is_consumer(params_.role); empty_barrier_ptr_[stage].arrive(params_.dst_blockid, (not skip)); } }; @@ -841,7 +959,7 @@ namespace PipelineDetail { template class PipelineAsync { -public : +public: static constexpr uint32_t Stages = Stages_; using SharedStorage = PipelineDetail::PipelineAsyncSharedStorage; using FullBarrier = typename SharedStorage::FullBarrier; @@ -864,33 +982,46 @@ public : uint32_t dst_blockid = cute::block_rank_in_cluster(); }; - // Default assumption when only storage is passed is : - // => single producer, single consumer & they are in the same block (within the Cluster) + static CUTLASS_DEVICE - PipelineAsync(SharedStorage& storage) - : PipelineAsync(storage, {}) {} + void + init_barriers(SharedStorage& storage, Params params) { + int warp_idx = canonical_warp_idx_sync(); + bool is_initializing_warp = (warp_idx == 0); + if (is_initializing_warp) { + // Barrier FULL and EMPTY init + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); + } + } + + template + CUTLASS_DEVICE + PipelineAsync( + SharedStorage& storage, + Params const& params, + InitBarriers = {}) : + params_(params), + full_barrier_ptr_(&storage.full_barrier_[0]), + empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_); + } + } CUTLASS_DEVICE PipelineAsync( SharedStorage& storage, Params const& params) : - params_(params), - full_barrier_ptr_(&storage.full_barrier_[0]), - empty_barrier_ptr_(&storage.empty_barrier_[0]) { + PipelineAsync(storage, params, cute::true_type{}) { } - int warp_idx = canonical_warp_idx_sync(); - int lane_predicate = cute::elect_one_sync(); - - // Barrier FULL, EMPTY init - // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate == 1) { - for (int i = 0; i < Stages; ++i) { - full_barrier_ptr_[i].init(params.producer_arv_count); - empty_barrier_ptr_[i].init(params.consumer_arv_count); - } - } - cutlass::arch::fence_barrier_init(); - } + // Default assumption when only storage is passed is : + // => single producer, single consumer & they are in the same block (within the Cluster) + CUTLASS_DEVICE + PipelineAsync(SharedStorage& storage) + : PipelineAsync(storage, {}, cute::true_type{}) {} //////////////////// // Producer APIs @@ -983,6 +1114,7 @@ private: CUTLASS_DEVICE ProducerToken producer_try_acquire(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_producer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -992,6 +1124,7 @@ private: CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { empty_barrier_ptr_[stage].wait(phase); } @@ -999,11 +1132,13 @@ private: CUTLASS_DEVICE void producer_commit(uint32_t stage) { + detail::pipeline_check_is_producer(params_.role); full_barrier_ptr_[stage].arrive(); } CUTLASS_DEVICE ConsumerToken consumer_try_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -1013,6 +1148,7 @@ private: CUTLASS_DEVICE ConsumerToken consumer_test_wait(uint32_t stage, uint32_t phase, uint32_t skip_wait) { + detail::pipeline_check_is_consumer(params_.role); if (skip_wait) { return {BarrierStatus::WaitDone}; } @@ -1022,6 +1158,7 @@ private: CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { + detail::pipeline_check_is_consumer(params_.role); bool done = full_barrier_ptr_[stage].test_wait(phase); if (!done) { full_barrier_ptr_[stage].wait(phase); @@ -1030,6 +1167,7 @@ private: CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase, ConsumerToken barrier_token) { + detail::pipeline_check_is_consumer(params_.role); if (barrier_token == BarrierStatus::WaitAgain) { full_barrier_ptr_[stage].wait(phase); } @@ -1037,6 +1175,7 @@ private: CUTLASS_DEVICE void consumer_release(uint32_t stage) { + detail::pipeline_check_is_consumer(params_.role); empty_barrier_ptr_[stage].arrive(params_.dst_blockid); } }; @@ -1075,7 +1214,7 @@ public: uint32_t group_size; }; -private : +private: // In future this Params object can be replaced easily with a CG object Params params_; Barrier *barrier_ptr_; @@ -1110,7 +1249,6 @@ public: } } } - cutlass::arch::fence_barrier_init(); } // Wait on a stage to be unlocked diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index ba1f7401..13e018db 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -106,7 +106,11 @@ #include #include #else -#include +#include +#include +#include +#include +#include #endif #if !defined(__CUDACC_RTC__) @@ -134,6 +138,10 @@ #define CUTLASS_OS_WINDOWS #endif +#if defined(__clang__) && defined(__CUDA__) +#define CUTLASS_CLANG_CUDA 1 +#endif + /****************************************************************************** * Macros ******************************************************************************/ @@ -298,30 +306,13 @@ namespace platform { #if defined(__CUDACC_RTC__) || (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500)) -/// std::integral_constant -template -struct integral_constant; - -/// std::integral_constant -template -struct integral_constant { - static const value_t value = V; - - typedef value_t value_type; - typedef integral_constant type; - - CUTLASS_HOST_DEVICE operator value_type() const { return value; } - - CUTLASS_HOST_DEVICE const value_type operator()() const { return value; } -}; - #else -using std::integral_constant; using std::pair; #endif +using CUTLASS_STL_NAMESPACE::integral_constant; using CUTLASS_STL_NAMESPACE::bool_constant; using CUTLASS_STL_NAMESPACE::true_type; using CUTLASS_STL_NAMESPACE::false_type; diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index aa4e3f1a..e8781562 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -35,15 +35,14 @@ #pragma once #if defined(__CUDACC_RTC__) -#include #include #else -#include -#include +#include #endif -#include "cutlass/cutlass.h" +#include +#include "cutlass/cutlass.h" #include "cutlass/platform/platform.h" namespace cutlass { diff --git a/include/cutlass/real.h b/include/cutlass/real.h index e53301b3..95a22444 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -35,6 +35,8 @@ #pragma once +#include // CUTLASS_DEVICE + namespace cutlass { /// Used to determine the real-valued underlying type of a numeric type T. diff --git a/include/cutlass/reduction/thread/reduction_operators.h b/include/cutlass/reduction/thread/reduction_operators.h index ba62c1b5..8423c2d9 100644 --- a/include/cutlass/reduction/thread/reduction_operators.h +++ b/include/cutlass/reduction/thread/reduction_operators.h @@ -172,7 +172,7 @@ struct ReduceArrayOperation, uint1b_t, N> { item = (item || !bits); } - return uint1b_t(!item); + return uint1b_t{!item}; } }; @@ -195,7 +195,7 @@ struct ReduceArrayOperation, uint1b_t, N> { item = (item || bits); } - return uint1b_t(item); + return uint1b_t{item}; } }; diff --git a/include/cutlass/tensor_view_planar_complex.h b/include/cutlass/tensor_view_planar_complex.h index c98de563..af63f80c 100644 --- a/include/cutlass/tensor_view_planar_complex.h +++ b/include/cutlass/tensor_view_planar_complex.h @@ -48,6 +48,7 @@ #include "cutlass/cutlass.h" #include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view.h" // cutlass::TensorView namespace cutlass { diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 8e7ab884..d6d265a4 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -40,6 +40,7 @@ #include #include #include +#include // std::memcpy #endif #include "cutlass/cutlass.h" diff --git a/include/cutlass/transform/device/transform_universal_adapter.hpp b/include/cutlass/transform/device/transform_universal_adapter.hpp index c7ab0ceb..a5033d80 100644 --- a/include/cutlass/transform/device/transform_universal_adapter.hpp +++ b/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -59,7 +59,7 @@ template class TransformUniversalAdapter { public: - using TransformKernel = TransformKernel_; + using TransformKernel = GetUnderlyingKernel_t; using Arguments = typename TransformKernel::Arguments; using Params = typename TransformKernel::Params; static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; diff --git a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp index 0ae7bab0..dd4fa0c1 100644 --- a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -267,7 +267,9 @@ private: case 3: return 0b11; default: + CUTLASS_ASSERT(false); CUTE_GCC_UNREACHABLE; + return 0b00; } }; diff --git a/media/docs/cute/04_algorithms.md b/media/docs/cute/04_algorithms.md index 353e9cc4..f427e5ef 100644 --- a/media/docs/cute/04_algorithms.md +++ b/media/docs/cute/04_algorithms.md @@ -100,7 +100,7 @@ void copy(Tensor const& src, // Any logical shape Tensor & dst) // Any logical shape { - for (int i = 0; i < size(src); ++i) { + for (int i = 0; i < size(dst); ++i) { dst(i) = src(i); } } diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 97ed6a63..29e5a0f6 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -185,7 +185,6 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS=90a # compiles for NVIDIA Hopper GP ``` **NVIDIA Ampere Architecture.** - ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture ``` diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index f9723c46..fad27837 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -57,6 +57,19 @@ CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path) # Alias CUTLASS_PATH as source_path source_path = CUTLASS_PATH +_NVCC_VERSION = None +def nvcc_version(): + global _NVCC_VERSION + if _NVCC_VERSION is None: + import subprocess + + # Attempt to get NVCC version + result = subprocess.run(['nvcc', '--version'], capture_output=True) + if result.returncode != 0: + raise Exception('Unable to run `nvcc --version') + _NVCC_VERSION = str(result.stdout).split(" release ")[-1].split(",")[0] + return _NVCC_VERSION + _CUDA_INSTALL_PATH = None def cuda_install_path(): """ diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index d72af78e..95e264cd 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -139,7 +139,7 @@ def get_tile_scheduler_arguments_3x( splits: int = 1): max_swizzle_size = 1 raster_order_option = 0 # Heuristic - if tile_scheduler == TileSchedulerType.Persistent: + if tile_scheduler in [TileSchedulerType.Default, TileSchedulerType.Persistent]: return _PersistentTileSchedulerArguments( max_swizzle_size, raster_order_option, diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index f52b1818..2c38397d 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -90,7 +90,7 @@ class CompilationOptions: opts.append(f"--include-path={incl}") arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90: + if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12: arch_flag += "a" opts.append(arch_flag) @@ -237,7 +237,7 @@ class ArtifactManager: if incl not in includes: includes.append(incl) - includes_host = ["builtin_types.h", "device_launch_parameters.h", "stddef.h"] + includes + includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes for incl in includes: source_buffer_device += SubstituteTemplate( IncludeTemplate, diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index e0b5e957..48366a76 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -44,6 +44,7 @@ from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torc dtype2ctype = { DataType.f16: ctypes.c_uint16, + DataType.bf16: ctypes.c_uint16, DataType.f32: ctypes.c_float, DataType.f64: ctypes.c_double, DataType.s8: ctypes.c_int8, diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index 575767d0..153b937e 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -59,18 +59,21 @@ def max(x, dim): elif is_torch_tensor(x): return torch.amax(x, dim) + def maximum(x, y): if is_numpy_tensor(x): return np.maximum(x, y) elif is_torch_tensor(x): return torch.maximum(x, torch.tensor(y)) - + + def minimum(x, y): if is_numpy_tensor(x): return np.minimum(x, y) elif is_torch_tensor(x): return torch.minimum(x, torch.tensor(y)) + ############################################################################## # Layout manipulate nodes ############################################################################## diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 7c16cc68..2a02f61c 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -51,6 +51,20 @@ _generator_ccs = [50, 60, 61, 70, 75, 80, 90] # Strip any additional information from the CUDA version _cuda_version = __version__.split("rc")[0] +# Check that Python CUDA version exceeds NVCC version +_nvcc_version = cutlass.nvcc_version() +_cuda_list = _cuda_version.split('.') +_nvcc_list = _cuda_version.split('.') +for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list): + if int(val_cuda) < int(val_nvcc): + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}") + +if len(_nvcc_list) > len(_cuda_list): + if len(_nvcc_list) != len(_cuda_list) + 1: + raise Exception(f"Malformatted NVCC version of {_nvcc_version}") + if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0: + raise Exception(f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {_nvcc_version}") + class KernelsForDataType: """ @@ -278,7 +292,7 @@ class ArchOptions: ] manifest_args = cutlass_library.generator.define_parser().parse_args(args) manifest = cutlass_library.manifest.Manifest(manifest_args) - generate_function(manifest, _cuda_version) + generate_function(manifest, _nvcc_version) if operation_kind not in manifest.operations: # No kernels generated for this architecture, this could be because the CUDA diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 6f8483ed..62a5474a 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -818,6 +818,7 @@ ${compile_guard_end} element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index e6a9f9e8..bd06a801 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -177,7 +177,7 @@ def CreateGemmUniversal3xOperator( complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Persistent]): + tile_schedulers=[TileSchedulerType.Default]): if type(data_types) is dict: data_types = [data_types] @@ -226,7 +226,7 @@ def CreateSparseGemmUniversal3xOperator( complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Persistent]): + tile_schedulers=[TileSchedulerType.Default]): if type(data_types) is dict: data_types = [data_types] @@ -1048,7 +1048,7 @@ def CreateConvOperator3x(manifest: Manifest, schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], complex_transforms: Optional[Sequence[ComplexTransform]] = None, - tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Persistent], + tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], conv_kind: ConvKind = ConvKind.Fprop, log_indent_level: int = 1): """ @@ -6508,6 +6508,7 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): data_type, alignment_constraints, BlasMode.hermitian) # + ################################################################################################### def GenerateSM90_Conv3x(manifest, cuda_version, @@ -6703,6 +6704,7 @@ def GenerateSM90_Conv3x(manifest, cuda_version, product( ( ConvKind.Dgrad, + ConvKind.Wgrad ), spatial_dims, ( diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index be9eef20..3ccfb403 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -75,6 +75,7 @@ class DataType(enum.Enum): u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() + s2 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() @@ -92,11 +93,13 @@ class DataType(enum.Enum): cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() + cs2 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() + cu2 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() @@ -126,6 +129,7 @@ DataTypeNames = { DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", + DataType.s2: "s2", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", @@ -143,11 +147,13 @@ DataTypeNames = { DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", + DataType.cu2: "cu2", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", + DataType.cs2: "cs2", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", @@ -164,6 +170,7 @@ DataTypeTag = { DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", + DataType.s2: "cutlass::int2b_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", @@ -181,11 +188,13 @@ DataTypeTag = { DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", + DataType.cu2: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", + DataType.cs2: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", @@ -202,6 +211,7 @@ DataTypeSize = { DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, + DataType.s2: 2, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, @@ -219,11 +229,13 @@ DataTypeSize = { DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, + DataType.cu2: 4, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, + DataType.cs2: 4, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 08fcd547..021406d7 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -492,6 +492,21 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if not (is_fp8 and is_sparse): schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 0): + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + # Inconsistency: fp8 pingpong only gets stamped out with fast accum + if not is_fp8 or level >= 1: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecialized + ]) + if can_do_fp8_fast_accum: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecialized + ]) if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # Pruning: don't stamp out fp8 ping-ponging kernel with non-tma epilogue @@ -526,17 +541,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, # persistent kernels with TMA epilogues if can_do_tma_epilogue: assert not requires_transposed_epilogue - # Inconsistency: fp8 pingpong only gets stamped out with fast accum - if not is_fp8 or level >= 1: - schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpong, - EpilogueScheduleType.TmaWarpSpecialized - ]) - if can_do_fp8_fast_accum: - schedules.append([ - KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecialized - ]) if can_do_cooperative: # Sparse kernels only support FastAccum FP8 mainloop if not (is_fp8 and is_sparse): diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index 3f9996cf..da6c1dec 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -118,6 +118,5 @@ class TestEVTCompute(EVTTestCaseBase): result_keys = ["D"] launcher.verify((m, n, k), input_keys, result_keys, l) - if __name__ == '__main__': unittest.main() diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index a9070746..a576868b 100644 --- a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -131,6 +131,115 @@ set(header_files_to_check cute/atom/mma_traits_sm80.hpp cute/atom/mma_traits_sm90.hpp cute/atom/mma_traits_sm90_gmma.hpp + # cutlass + cutlass/aligned_buffer.h + cutlass/array.h + cutlass/array_planar_complex.h + cutlass/array_subbyte.h + cutlass/barrier.h + cutlass/bfloat16.h + cutlass/blas3.h + cutlass/blas3_types.h + cutlass/block_striped.h + cutlass/cluster_launch.hpp + cutlass/complex.h + cutlass/constants.h + cutlass/coord.h + cutlass/core_io.h + cutlass/cuda_host_adapter.hpp + cutlass/cutlass.h + cutlass/device_kernel.h + cutlass/fast_math.h + cutlass/float8.h + # cutlass/floating_point_nvrtc.h + cutlass/functional.h + cutlass/gemm_coord.h + cutlass/gemm_coord.hpp + cutlass/half.h + cutlass/integer_subbyte.h + cutlass/kernel_hardware_info.h + cutlass/kernel_hardware_info.hpp + cutlass/kernel_launch.h + cutlass/matrix.h + cutlass/matrix_coord.h + cutlass/matrix_shape.h + cutlass/numeric_conversion.h + cutlass/numeric_size.h + cutlass/numeric_types.h + cutlass/pitch_linear_coord.h + cutlass/predicate.h + cutlass/predicate_vector.h + cutlass/quaternion.h + cutlass/real.h + cutlass/relatively_equal.h + cutlass/semaphore.h + cutlass/subbyte_reference.h + cutlass/tensor_coord.h + cutlass/tensor_ref.h + cutlass/tensor_ref_planar_complex.h + cutlass/tensor_view.h + cutlass/tensor_view_planar_complex.h + cutlass/tfloat32.h + cutlass/trace.h + cutlass/uint128.h + cutlass/version.h + cutlass/wmma_array.h + cutlass/workspace.h + # cutlass/platform + cutlass/platform/platform.h + + # cutlass/pipeline + cutlass/pipeline/pipeline.hpp + cutlass/pipeline/sm90_pipeline.hpp + # cutlass/detail + cutlass/detail/cluster.hpp + cutlass/detail/collective.hpp + cutlass/detail/dependent_false.hpp + cutlass/detail/helper_macros.hpp + cutlass/detail/layout.hpp + cutlass/detail/mainloop_fusion_helper_bgrada.hpp + cutlass/detail/mma.hpp + # cutlass/arch + cutlass/arch/arch.h + cutlass/arch/barrier.h + cutlass/arch/cache_operation.h + cutlass/arch/config.h + cutlass/arch/custom_abi.h + cutlass/arch/grid_dependency_control.h + cutlass/arch/memory.h + # cutlass/arch/memory_sm75.h + # cutlass/arch/memory_sm80.h + cutlass/arch/mma.h + # cutlass/arch/mma_sm50.h + # cutlass/arch/mma_sm60.h + # cutlass/arch/mma_sm61.h + # cutlass/arch/mma_sm70.h + # cutlass/arch/mma_sm75.h + # cutlass/arch/mma_sm80.h + # cutlass/arch/mma_sm89.h + # cutlass/arch/mma_sm90.h + cutlass/arch/mma_sparse_sm80.h + cutlass/arch/mma_sparse_sm89.h + # cutlass/arch/simd.h + # cutlass/arch/simd_sm60.h + # cutlass/arch/simd_sm61.h + cutlass/arch/reg_reconfig.h + cutlass/arch/tma_operation.h + cutlass/arch/wmma.h + # cutlass/arch/wmma_sm70.h + # cutlass/arch/wmma_sm72.h + # cutlass/arch/wmma_sm75.h + # cutlass/arch/wmma_sm80.h + # cutlass/layout + cutlass/layout/layout.h + cutlass/layout/matrix.h + cutlass/layout/permute.h + cutlass/layout/pitch_linear.h + cutlass/layout/tensor.h + cutlass/layout/tensor_op_multiplicand_sm70.h + cutlass/layout/tensor_op_multiplicand_sm75.h + cutlass/layout/tensor_op_multiplicand_sm80.h + cutlass/layout/vector.h ) # for each header in _header_files: diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 7c4864f6..b02ec65a 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -63,7 +63,7 @@ set(CUTLASS_TEST_UNIT_RESULTS_CACHE_DIR ${CMAKE_CURRENT_LIST_DIR}/data/hashes) function(cutlass_test_unit_add_executable NAME) - set(options WITHOUT_CUDA) + set(options WITHOUT_CUDA DO_NOT_LOWERCASE_TEST_NAME) set(oneValueArgs) set(multiValueArgs TEST_SETS_SUPPORTED EXTRA_INCLUDE_DIRS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -109,14 +109,22 @@ function(cutlass_test_unit_add_executable NAME) set(CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS --gtest_output=xml:${NAME_STEM}.gtest.xml) + if (__DO_NOT_LOWERCASE_TEST_NAME) + set(DO_NOT_LOWERCASE_TEST_NAME DO_NOT_LOWERCASE_TEST_NAME) + else() + set(DO_NOT_LOWERCASE_TEST_NAME) + endif() + cutlass_add_executable_tests( ${NAME_STEM} ${NAME} TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED} TEST_COMMAND_OPTIONS CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS ${RESULT_CACHE_FILE_ARGS} + ${DO_NOT_LOWERCASE_TEST_NAME} ) endfunction() + add_custom_target(cutlass_test_unit) add_custom_target(test_unit) diff --git a/test/unit/common/filter_architecture.cpp b/test/unit/common/filter_architecture.cpp index 5171eb5c..32acad1e 100644 --- a/test/unit/common/filter_architecture.cpp +++ b/test/unit/common/filter_architecture.cpp @@ -87,7 +87,6 @@ void FilterArchitecture() { << " [" << cudaGetErrorString(err) << "]" << std::endl; exit(1); } - cudaDeviceProp deviceProperties; err = cudaGetDeviceProperties(&deviceProperties, cudaDeviceId); if (cudaSuccess != err) { diff --git a/test/unit/conv/device_3x/conv_problem_sizes.hpp b/test/unit/conv/device_3x/conv_problem_sizes.hpp index ef651712..d66de64a 100644 --- a/test/unit/conv/device_3x/conv_problem_sizes.hpp +++ b/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -1159,6 +1159,37 @@ std::vector> get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { using ProblemShape = cutlass::conv::ConvProblemShape; std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); // non-packed input/output strides. // stride divides dilation // asymmetric padding diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index 3227f3d6..b392165c 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -336,10 +336,17 @@ struct ConvTestbed { // Scale if constexpr (cute::is_same_v> || - cute::is_same_v>) { + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { fusion_args.activation.scale = ElementCompute{1}; } + // LeakyRelu + if constexpr (cute::is_same_v> ) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + cutlass::Status status = cutlass::Status::kInvalid; status = conv_op.can_implement(args); @@ -617,8 +624,9 @@ bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f for (DecompositionMode decomp_mode : decomposition_modes) { std::vector problem_splits = {Splits{1}}; if constexpr (UsesStreamKScheduler) { - if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + if (decomp_mode == DecompositionMode::SplitK) { problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); } } for (auto splits : problem_splits) { diff --git a/test/unit/core/float8.cu b/test/unit/core/float8.cu index 6fd04485..14d9d22b 100644 --- a/test/unit/core/float8.cu +++ b/test/unit/core/float8.cu @@ -35,6 +35,7 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" #include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt index 6ac7f2f2..c1a654e8 100644 --- a/test/unit/cute/ampere/CMakeLists.txt +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -28,7 +28,7 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_ampere - cp_async.cu + cp_sync.cu ldsm.cu cooperative_gemm.cu cooperative_copy.cu diff --git a/test/unit/cute/ampere/cooperative_copy.cu b/test/unit/cute/ampere/cooperative_copy.cu index a91000cb..fef61aa2 100644 --- a/test/unit/cute/ampere/cooperative_copy.cu +++ b/test/unit/cute/ampere/cooperative_copy.cu @@ -46,6 +46,7 @@ #include // cute::Swizzle #include // cute::compose(cute::Swizzle) #include +#include using namespace cute; @@ -71,7 +72,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), gmem_layout); Tensor s_tensor = make_tensor(make_smem_ptr(smem), smem_layout); - cooperative_copy(threadIdx.x, g_in_tensor, s_tensor); + cooperative_copy(threadIdx.x, g_in_tensor, s_tensor, AutoCopyAsync{}); cp_async_fence(); cp_async_wait<0>(); @@ -84,7 +85,7 @@ cooperative_copy_default_gs(T const* g_in, T* g_out, GMemLayout const& gmem_layo } __syncthreads(); - cooperative_copy(threadIdx.x, s_tensor, g_out_tensor); + cooperative_copy(threadIdx.x, s_tensor, g_out_tensor, AutoCopyAsync{}); } // ss --> shared to shared @@ -106,7 +107,7 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay Tensor s1_tensor = make_tensor(make_smem_ptr(smem1), layout2); Tensor s2_tensor = make_tensor(make_smem_ptr(smem2), layout1); - cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor); + cooperative_copy>(threadIdx.x, g_in_tensor, s1_tensor, AutoCopyAsync{}); cp_async_fence(); cp_async_wait<0>(); @@ -119,10 +120,10 @@ cooperative_copy_default_ss(T const* g_in, T* g_out, Layout1 const& layout1, Lay } __syncthreads(); - cooperative_copy(threadIdx.x, s1_tensor, s2_tensor); + cooperative_copy(threadIdx.x, s1_tensor, s2_tensor, AutoCopyAsync{}); __syncthreads(); - cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor); + cooperative_copy>(threadIdx.x, s2_tensor, g_out_tensor, AutoCopyAsync{}); } // gg --> global to global @@ -135,7 +136,7 @@ cooperative_copy_default_gg(T const* g_in, T* g_out, Layout1 const& layout1, Lay Tensor g_in_tensor = make_tensor(make_gmem_ptr(g_in), layout1); Tensor g_out_tensor = make_tensor(make_gmem_ptr(g_out), layout2); - cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor); + cooperative_copy(threadIdx.x, g_in_tensor, g_out_tensor, AutoCopyAsync{}); } template @@ -252,7 +253,7 @@ typedef testing::Types< std::tuple>, std::tuple>, std::tuple>, - std::tuple>, + std::tuple> > CooperativeCopyModeMaxVecBitsList; TYPED_TEST_SUITE(SM80_CuTe_Ampere, CooperativeCopyModeMaxVecBitsList); diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index 2ba01933..5bb6ecd2 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -40,406 +40,462 @@ using namespace cute; TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA) { + constexpr uint32_t thread_block_size = 128; using value_type = cutlass::half_t; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA) { + constexpr uint32_t thread_block_size = 128; using value_type = double; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm3_Half_MMA_CustomSmemLayouts) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using value_type = cutlass::half_t; - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 128; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_128, _128, _128>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` - >; + >{}; - using smem_a_atom_layout_t = Layout, Stride< _1,_64>>; - using smem_b_atom_layout_t = Layout, Stride<_32, _1>>; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + auto smem_a_atom_layout = Layout, Stride< _1,_64>>{}; + auto smem_b_atom_layout = Layout, Stride<_32, _1>>{}; + auto smem_c_atom_layout = make_layout(select<0,1>(shape_mnk)); - test_cooperative_gemm_col_major_layout(); + value_type> + (smem_a_atom_layout, + smem_b_atom_layout, + smem_c_atom_layout, + shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm4_Half_MMA_SwizzledSmemLayouts) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using value_type = cutlass::half_t; - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 128; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_128, _128, _128>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` - >; + >{}; // RowMajor - using smem_rowmajor_atom_layout_t = decltype( + auto smem_a_atom_layout = composition(Swizzle<3,3,3>{}, Layout, - Stride<_64, _1>>{})); + Stride<_64, _1>>{}); // ColMajor - using smem_colmajor_atom_layout_t = decltype( + auto smem_b_atom_layout = composition(Swizzle<3,3,3>{}, Layout, - Stride< _1,_64>>{})); - using smem_a_atom_layout_t = smem_rowmajor_atom_layout_t; - using smem_b_atom_layout_t = smem_colmajor_atom_layout_t; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + Stride< _1,_64>>{}); - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); - test_cooperative_gemm, // C - thread_block_size, - tiled_mma_t, - 128, + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + cute::identity{}, // TransformLoadA + cute::identity{}, // TransformLoadB + cute::identity{}, // TransformLoadC + cute::identity{}, // TransformStoreC + SM75_U32x4_LDSM_N{}, // A + SM75_U16x8_LDSM_T{}, // B + AutoVectorizingCopyWithAssumedAlignment<128>{}); // C } TEST(SM80_CuTe_Ampere, CooperativeGemm5_Double_MMA_SwizzledSmemLayouts) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using value_type = double; - constexpr uint32_t m = 128; - constexpr uint32_t n = 64; - constexpr uint32_t k = 16; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_128, _64, _16>{}; + auto tiled_mma = TiledMMA, // Atom Layout>, // Atom layout Tile, Stride<_2, _1>>, // 32x32x4 MMA with perm for load vectorization Layout, Stride<_2, _1>>, - Underscore>>; + Underscore>>{}; - using smem_a_atom_layout_t = decltype( + auto smem_a_atom_layout = composition(Swizzle<2,2,2>{}, Layout, - Stride< _1,_16>>{})); // M, K - using smem_b_atom_layout_t = decltype( + Stride< _1,_16>>{}); // M, K + auto smem_b_atom_layout = composition(Swizzle<2,2,2>{}, Layout, - Stride< _1,_16>>{})); // N, K - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + Stride< _1,_16>>{}); // N, K - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm6_MixedPrecisionFP16FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::half_t; using TB = cutlass::half_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm7_MixedPrecisionBF16FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::bfloat16_t; using TB = cutlass::bfloat16_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::tfloat32_t; using TB = cutlass::tfloat32_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } -TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) { - +TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA_Dynamic) { + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; using TA = cutlass::complex; using TB = cutlass::complex; using TC = cutlass::complex; - constexpr uint32_t thread_block_size = 256; - constexpr int MaxVecBits = 128; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _4, _0>>, Tile - >; + >{}; - using ALayout = Layout,Int<35>>, Stride, Int<1> >>; - using BLayout = Layout, Int<35>>, Stride, Int<1> >>; - using CLayout = Layout, Int<7>>, Stride< Int<1>, Int<30>>>; + auto a_layout = make_layout(Shape,Int<35>>{}, make_stride(44, 1)); + auto b_layout = make_layout(Shape< Int<7>, Int<35>>{}, make_stride(44, 1)); + auto c_layout = make_layout(Shape, Int<7>>{}, make_stride(1, 30)); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + test_cooperative_gemm(); + TA, TB, TC> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); +} +TEST(SM80_CuTe_Ampere, CooperativeGemm9_C64C64C64_MMA) { + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; + using TA = cutlass::complex; + using TB = cutlass::complex; + using TC = cutlass::complex; + + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout, Stride<_1, _4, _0>>, + Tile + >{}; + + auto a_layout = Layout,Int<35>>, Stride, Int<1> >>{}; + auto b_layout = Layout, Int<35>>, Stride, Int<1> >>{}; + auto c_layout = Layout, Int<7>>, Stride< Int<1>, Int<30>>>{}; + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemm10_F16F64F16_FMA) { + constexpr uint32_t thread_block_size = 256; + constexpr int MaxVecBits = 128; using TA = cutlass::half_t; using TB = double; using TC = cutlass::half_t; - constexpr uint32_t thread_block_size = 256; - constexpr int MaxVecBits = 128; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom>, Layout, Stride<_1, _16, _0>>, Tile - >; + >{}; - using ALayout = Layout,Int<64>>, Stride, Int< 1>>>; - using BLayout = Layout,Int<64>>, Stride, Int<64>>>; - using CLayout = Layout,Int<64>>, Stride, Int<64>>>; + auto a_layout = Layout,Int<64>>, Stride, Int< 1>>>{}; + auto b_layout = Layout,Int<64>>, Stride, Int<64>>>{}; + auto c_layout = Layout,Int<64>>, Stride, Int<64>>>{}; - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + test_cooperative_gemm(); + TC> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } TEST(SM80_CuTe_Ampere, CooperativeGemmComposedStride) { - using T = cute::half_t; - constexpr uint32_t thread_block_size = 128; constexpr int MaxVecBits = 16; + using T = cute::half_t; - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _2, _0>>, Tile - >; + >{}; - using swizzle = cute::Swizzle<3, 3, 3>; - using offset = cute::_0; - using atom_tile_right = decltype(cute::make_layout(cute::Shape{}, cute::LayoutRight{})); - using FP16AtomLayoutRight = decltype(cute::composition(swizzle{}, offset{}, atom_tile_right{})); + auto swizzle = cute::Swizzle<3, 3, 3>{}; + auto offset = cute::_0{}; + auto atom_tile_right = cute::make_layout(cute::Shape{}, cute::LayoutRight{}); + auto FP16AtomLayoutRight = cute::composition(swizzle, offset, atom_tile_right); - using shape = cute::Shape, cute::Int<128>>; - using global_a_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); - using global_b_layout = decltype(cute::make_layout(shape{}, cute::LayoutLeft{})); - using global_c_layout = decltype(cute::make_layout(shape{}, cute::LayoutRight{})); + auto shape = cute::Shape, cute::Int<128>>{}; + auto global_a_layout = cute::make_layout(shape, cute::LayoutRight{}); + auto global_b_layout = cute::make_layout(shape, cute::LayoutLeft{}); + auto global_c_layout = cute::make_layout(shape, cute::LayoutRight{}); // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_a_layout{})); - using BLayout = decltype(cute::tile_to_shape(FP16AtomLayoutRight{}, global_b_layout{})); - using CLayout = global_c_layout; + auto a_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_a_layout); + auto b_layout = cute::tile_to_shape(FP16AtomLayoutRight, global_b_layout); + auto c_layout = global_c_layout; - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + test_cooperative_gemm(); + T, T, T> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } -TEST(SM89_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) { +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_Transform) { + constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; using TA = cutlass::tfloat32_t; using TB = cutlass::tfloat32_t; using TC = float; - constexpr uint32_t m = 9; - constexpr uint32_t n = 9; - constexpr uint32_t k = 9; - - constexpr uint32_t thread_block_size = 64; - - using tiled_mma_t = + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{}); + test_cooperative_gemm_col_major_layout + (shape_mnk, tiled_mma, cute::negate{}, cute::negate{}, cute::negate{}, cute::negate{}); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecision) { + constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; + using InputTA = cutlass::half_t; + using InputTB = cutlass::half_t; + using InputTC = cutlass::half_t; + + using ComputeTA = cutlass::tfloat32_t; + using ComputeTB = cutlass::tfloat32_t; + using ComputeTC = float; + + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout + (shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_TransformPrecisionReg) { + constexpr uint32_t thread_block_size = 64; + constexpr uint32_t max_vec_bits = 16; + using InputTA = cutlass::half_t; + using InputTB = cutlass::half_t; + using InputTC = cutlass::half_t; + + using ComputeTA = cutlass::tfloat32_t; + using ComputeTB = cutlass::tfloat32_t; + using ComputeTC = float; + + auto shape_mnk = Shape, C<9>, C<9>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c + (shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA_Reg) { + using value_type = cutlass::half_t; + + auto shape_mnk = Shape<_64, _64, _64>{}; + + constexpr uint32_t thread_block_size = 128; + + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Reg) { + constexpr uint32_t thread_block_size = 128; + using value_type = double; + + auto shape_mnk = Shape<_64, _64, _64>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA_Predicated_Reg) { + constexpr uint32_t thread_block_size = 128; + using value_type = double; + + auto shape_mnk = Shape, C<62>, C<62>>{}; + auto tiled_mma = + TiledMMA< + MMA_Atom, + Layout> + >{}; + + test_cooperative_gemm_col_major_layout_rmem_c(shape_mnk, tiled_mma); } diff --git a/test/unit/cute/ampere/cp_async.cu b/test/unit/cute/ampere/cp_sync.cu similarity index 97% rename from test/unit/cute/ampere/cp_async.cu rename to test/unit/cute/ampere/cp_sync.cu index 871e8ff9..f5045410 100644 --- a/test/unit/cute/ampere/cp_async.cu +++ b/test/unit/cute/ampere/cp_sync.cu @@ -69,14 +69,12 @@ test2(double const* g_in, double* g_out) copy(g_tensor, s_tensor); - cp_async_fence(); - cp_async_wait<0>(); __syncthreads(); g_out[threadIdx.x] = 2 * smem[threadIdx.x]; } -TEST(SM80_CuTe_Ampere, CpAsync) +TEST(SM80_CuTe_Ampere, CpSync) { constexpr int count = 32; thrust::host_vector h_in(count); diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp index 5dec22ca..dbb85e6b 100644 --- a/test/unit/cute/cooperative_gemm_common.hpp +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -54,198 +54,76 @@ struct fp64_tester> { using value_type = complex; }; -template -__launch_bounds__(ThreadBlockSize) __global__ void -cooperative_gemm_kernel(TA const* a, - TB const* b, - TC* c, - TC* c_out, - Alpha const alpha, - Beta const beta, - ALoadTransform a_load_transform, - BLoadTransform b_load_transform, - CLoadTransform c_load_transform, - CStoreTransform c_store_transform) -{ - using namespace cute; + class ALayout, // logical shape (M, K) + class BLayout, // logical shape (N, K) + class CLayout> // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); - Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), ALayout{}); - Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), BLayout{}); - Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), CLayout{}); - Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), CLayout{}); + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - - extern __shared__ float4 smem_buf[]; - auto* smem_ptr = reinterpret_cast(smem_buf); - auto* smem_ptr_a = smem_ptr; - auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(SMemALayout {})), copy_max_vec_bytes); - auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(SMemBLayout {})), copy_max_vec_bytes); - - Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), SMemALayout{}); - Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), SMemBLayout{}); - Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), SMemCLayout{}); - - cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); - cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); - cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); - - cp_async_fence(); - cp_async_wait<0>(); - __syncthreads(); - - TiledMma tiled_mma; - cooperative_gemm( - threadIdx.x, tiled_mma, - alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, - a_load_transform, b_load_transform, c_load_transform, c_store_transform - ); - __syncthreads(); - - cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); + return std::make_tuple(h_a, h_b, h_c, h_c_out); } -template -void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) -{ - using gmem_a_layout_t = ALayout; - using gmem_b_layout_t = BLayout; - using gmem_c_layout_t = CLayout; +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = remove_cv_t; - using smem_a_layout_t = SMemALayout; - using smem_b_layout_t = SMemBLayout; - using smem_c_layout_t = SMemCLayout; + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - using tester = fp64_tester; - using ABC_64 = typename tester::value_type; - static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM - static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN - static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK - - static_assert(size<0>(smem_a_layout_t{}) == size<0>(smem_c_layout_t{})); // AM == CM - static_assert(size<0>(smem_b_layout_t{}) == size<1>(smem_c_layout_t{})); // BN == CN - static_assert(size<1>(smem_a_layout_t{}) == size<1>(smem_b_layout_t{})); // AK == BK - - static_assert(cute::size(gmem_a_layout_t {}) == cute::size(smem_a_layout_t {})); - static_assert(cute::size(gmem_b_layout_t {}) == cute::size(smem_b_layout_t {})); - static_assert(cute::size(gmem_c_layout_t {}) == cute::size(smem_c_layout_t {})); - -#if 0 - print(" "); print("gmem: "); print(gmem_layout_t{}); print("\n"); - print(" "); print("smem: "); print(smem_layout_t{}); print("\n"); - print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); -#endif - - const auto alpha = static_cast(1.1); - const auto beta = static_cast(1.2); - - thrust::host_vector h_a(cosize(gmem_a_layout_t{})); - thrust::host_vector h_b(cosize(gmem_b_layout_t{})); - thrust::host_vector h_c(cosize(gmem_c_layout_t{})); - thrust::host_vector h_c_out(cosize(gmem_c_layout_t{})); - - auto h_a_tensor = make_tensor(h_a.data(), gmem_a_layout_t{}); - auto h_b_tensor = make_tensor(h_b.data(), gmem_b_layout_t{}); - auto h_c_tensor = make_tensor(h_c.data(), gmem_c_layout_t{}); - size_t max_size = std::max({static_cast(size(gmem_a_layout_t {})), - static_cast(size(gmem_b_layout_t {})), - static_cast(size(gmem_c_layout_t {}))}); - for (size_t i = 0; i < max_size; ++i) { - double di = static_cast(i); - if(i < size(gmem_a_layout_t{})) { - h_a_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); - } - if(i < size(gmem_b_layout_t{})) { - h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); - } - if(i < size(gmem_c_layout_t{})) { - h_c_tensor(i) = static_cast((di*di) / size(gmem_a_layout_t{})); - } - } - - thrust::device_vector d_a(h_a); - thrust::device_vector d_b(h_b); - thrust::device_vector d_c(h_c); - thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); - - constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; - const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) - + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) - + (sizeof(TC) * h_c.size()); - auto kernel = cooperative_gemm_kernel< - gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t, - smem_a_layout_t, smem_b_layout_t, smem_c_layout_t, - SmemCopyOpA, SmemCopyOpB, SmemCopyOpC, - ThreadBlockSize, TiledMma, CopyMaxVecBits, - TA, TB, TC, decltype(alpha), decltype(beta), - ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform - >; - ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); - - kernel<<<1, ThreadBlockSize, shared_memory_size>>>( - thrust::raw_pointer_cast(d_a.data()), - thrust::raw_pointer_cast(d_b.data()), - thrust::raw_pointer_cast(d_c.data()), - thrust::raw_pointer_cast(d_c_out.data()), - alpha, - beta, - a_load_transform, - b_load_transform, - c_load_transform, - c_store_transform - ); - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - cudaError_t error = cudaGetLastError(); - FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; - } - - thrust::host_vector h_c_ref(h_c.size(), static_cast(0.0)); - auto h_c_ref_tensor = make_tensor(h_c_ref.data(), gmem_c_layout_t{}); + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); // A * B for (int k = 0; k < size<1>(h_a_tensor); k++) { for (int m = 0; m < size<0>(h_a_tensor); m++) { @@ -265,8 +143,20 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); } - h_c_out = d_c_out; - auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{}); + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + for (int i = 0; i < size(h_c_ref_tensor); i++) { ABC_64 h_c_ref_i = h_c_ref_tensor(i); ABC_64 h_c_out_i = h_c_out_tensor(i); @@ -277,156 +167,604 @@ void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, } } -template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyOpC c_copy_op) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_op + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } + + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) + class GMemALayout, // logical shape (M, K) + class GMemBLayout, // logical shape (N, K) + class GMemCLayout, // logical shape (M, N) + class SMemALayout, // logical shape (M, K) + class SMemBLayout, // logical shape (N, K) + class SMemCLayout, // logical shape (M, N) + class TiledMma, + class ALoadTransform = cute::identity, + class BLoadTransform = cute::identity, + class CLoadTransform = cute::identity, + class CStoreTransform = cute::identity, + class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyOp c_smem_copy_op = {}) { - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); - using smem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK - test_cooperative_gemm>, - AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment>, - ThreadBlockSize, - TiledMMAType, - CopyMaxVecBits, - TA, - TB, - TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto beta = static_cast(1.2); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + + auto kernel = cooperative_gemm_kernel< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, + TA, TB, TC, decltype(alpha), decltype(beta), + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + tiled_mma, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); } -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) -{ - test_cooperative_gemm_col_major_layout, T, T, T>( - a_load_transform, b_load_transform, c_load_transform, c_store_transform); -} - -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) + class GMemALayout, // logical shape (M, K) + class GMemBLayout, // logical shape (N, K) + class GMemCLayout, // logical shape (M, N) + class SMemALayout, // logical shape (M, K) + class SMemBLayout, // logical shape (N, K) + class TiledMma, + class ALoadTransform = cute::identity, + class BLoadTransform = cute::identity, + class CLoadTransform = cute::identity, + class CStoreTransform = cute::identity, + class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) { - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK - using smem_a_atom_layout_t = SMemAAtomLayout; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op ); - using smem_b_atom_layout_t = SMemBAtomLayout; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } - using smem_c_atom_layout_t = SMemCAtomLayout; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); + // Copy result data + h_c_out = d_c_out; - test_cooperative_gemm>, - AutoVectorizingCopyWithAssumedAlignment>, - AutoVectorizingCopyWithAssumedAlignment>, - ThreadBlockSize, - TiledMMAType, - CopyMaxVecBits, - TA, - TB, - TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); } -template -void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, - BLoadTransform const& b_load_transform = {}, - CLoadTransform const& c_load_transform = {}, - CStoreTransform const& c_store_transform = {}) +template +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) { - test_cooperative_gemm_col_major_layout(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); +} + + +template +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); +} + +template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template +void test_cooperative_gemm_col_major_layout(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout, - T, - T, - T>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); + T, T, T> + (static_cast(args)...); } diff --git a/test/unit/cute/core/inverse_left.cpp b/test/unit/cute/core/inverse_left.cpp index 363cb17f..142d80fb 100644 --- a/test/unit/cute/core/inverse_left.cpp +++ b/test/unit/cute/core/inverse_left.cpp @@ -104,7 +104,7 @@ TEST(CuTe_core, Inverse_left) auto layout = Layout, Stride<_4, _1>>{}; - test_left_inverse(filter(layout)); + test_left_inverse(layout); } { diff --git a/test/unit/cute/hopper/cooperative_gemm.cu b/test/unit/cute/hopper/cooperative_gemm.cu index c4e2274d..7d992510 100644 --- a/test/unit/cute/hopper/cooperative_gemm.cu +++ b/test/unit/cute/hopper/cooperative_gemm.cu @@ -44,91 +44,74 @@ using namespace cute; #if USE_FP8 TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF8) { + constexpr uint32_t thread_block_size = 128; + constexpr int MaxVecBits = 16; using TA = uint8_t; using TB = uint8_t; using TC = uint32_t; - constexpr uint32_t thread_block_size = 128; - constexpr int MaxVecBits = 16; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _2, _0>>, Tile<_32, _32, _32> - >; + >{}; - using swizzle = Swizzle<2, 4, 3>; + auto swizzle = Swizzle<2, 4, 3>{}; // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(composition(swizzle{}, Layout, Stride<_64, _1>>{})); - using BLayout = decltype(composition(swizzle{}, Layout, Stride<_1, _64>>{})); + auto a_layout = composition(swizzle, Layout, Stride<_64, _1>>{}); + auto b_layout = composition(swizzle, Layout, Stride<_1, _64>>{}); + auto c_layout = make_layout(Shape<_64, _64>{}, LayoutLeft{}); - using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, + test_cooperative_gemm(); - + TA, TB, TC> + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma); } #else TEST(SM90_CuTe_Hopper, CooperativeGemmTilingF16) { + constexpr uint32_t thread_block_size = 64; + constexpr int max_vec_bits = 16; using TA = half_t; using TB = half_t; using TC = half_t; - constexpr uint32_t thread_block_size = 64; - constexpr int MaxVecBits = 16; - - using tiled_mma_t = + auto tiled_mma = TiledMMA< MMA_Atom, Layout, Stride<_1, _0, _0>>, Tile<_32, _32, _32> - >; - - using swizzle = Swizzle<3, 3, 3>; + >{}; // This is for A row major, B col major according to CUTLASS default configs - using ALayout = decltype(composition(swizzle{}, - Layout, Stride<_64, _1>>{})); + auto swizzle = Swizzle<3, 3, 3>{}; + auto ALayout = composition(swizzle{}, Layout, Stride<_64, _1>>{}); + auto BLayout = composition(swizzle{}, Layout, Stride<_1, _64>>{}); + auto CLayout = make_layout(Shape<_64, _64>{}, LayoutLeft{}); - using BLayout = decltype(composition(swizzle{}, - Layout, Stride<_1, _64>>{})); - - using CLayout = decltype(make_layout(Shape<_64, _64>{}, LayoutLeft{})); - - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment, // B - AutoVectorizingCopyWithAssumedAlignment, // C - thread_block_size, - tiled_mma_t, - MaxVecBits, + test_cooperative_gemm(); + TC> + (ALayout, + BLayout, + CLayout, + ALayout, + BLayout, + CLayout, + tiled_mma); } #endif diff --git a/test/unit/cute/hopper/tma_load_testbed.hpp b/test/unit/cute/hopper/tma_load_testbed.hpp index 0c8ed91d..58d19e4a 100644 --- a/test/unit/cute/hopper/tma_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_load_testbed.hpp @@ -131,7 +131,7 @@ tma_test_device_cute(T const* g_in, T* g_out, for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); if (threadIdx.x == 0) { diff --git a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp index 2fb88de5..bca37879 100644 --- a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -146,7 +146,7 @@ tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); if (elect_one_thr) { diff --git a/test/unit/cute/turing/cooperative_gemm.cu b/test/unit/cute/turing/cooperative_gemm.cu index 14ea9670..1bda5cf7 100644 --- a/test/unit/cute/turing/cooperative_gemm.cu +++ b/test/unit/cute/turing/cooperative_gemm.cu @@ -38,21 +38,19 @@ using namespace cute; TEST(SM75_CuTe_Turing, CooperativeGemm1_MixedPrecisionFP16FP32_MMA) { + + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; using TA = cutlass::half_t; using TB = cutlass::half_t; using TC = float; - constexpr uint32_t m = 64; - constexpr uint32_t n = 64; - constexpr uint32_t k = 64; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = make_shape(_64{}, _64{}, _64{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } diff --git a/test/unit/cute/volta/cooperative_gemm.cu b/test/unit/cute/volta/cooperative_gemm.cu index 157913f9..54cf4f22 100644 --- a/test/unit/cute/volta/cooperative_gemm.cu +++ b/test/unit/cute/volta/cooperative_gemm.cu @@ -40,105 +40,85 @@ using namespace cute; TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) { - using value_type = float; - - constexpr uint32_t m = 64; - constexpr uint32_t n = 32; - constexpr uint32_t k = 16; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(_64{}, _32{}, _16{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) { - using value_type = float; - - constexpr uint32_t m = 88; - constexpr uint32_t n = 20; - constexpr uint32_t k = 12; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(C<88>{}, C<20>{}, C<12>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) { - using value_type = float; - - constexpr uint32_t m = 88; - constexpr uint32_t n = 36; - constexpr uint32_t k = 24; constexpr uint32_t thread_block_size = 128; + using value_type = float; - using tiled_mma_t = + auto shape_mnk = make_shape(C<88>{}, C<36>{}, C<24>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) { + constexpr uint32_t thread_block_size = 128; using value_type = float; - constexpr uint32_t m = 67; - constexpr uint32_t n = 13; - constexpr uint32_t k = 11; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = make_shape(C<67>{}, C<13>{}, C<11>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) { + constexpr uint32_t thread_block_size = 128; using value_type = double; - constexpr uint32_t m = 16; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = + auto shape_mnk = make_shape(C<16>{}, C<32>{}, C<32>{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { - using value_type = float; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 256; + using value_type = float; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom< UniversalFMA >, @@ -154,228 +134,188 @@ TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { >, Underscore > - >; + >{}; - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) { + constexpr uint32_t thread_block_size = 128; using value_type = cutlass::half_t; - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV; - using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV; - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + auto smem_a_atom_layout = typename decltype(tiled_mma)::AtomLayoutB_TV{}; + auto smem_b_atom_layout = typename decltype(tiled_mma)::AtomLayoutA_TV{}; + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk)); - test_cooperative_gemm_col_major_layout(); + test_cooperative_gemm_col_major_layout + (smem_a_atom_layout, + smem_b_atom_layout, + smem_c_atom_layout, + shape_mnk, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); - using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + auto smem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto smem_c_layout = make_layout(select<0, 1>(shape_mnk)); - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 31; - constexpr uint32_t n = 27; - constexpr uint32_t k = 17; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 16; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(C<31>{}, C<27>{}, C<17>{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); - using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); - using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); - using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + auto smem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto smem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto smem_c_layout = make_layout(select<0, 1>(shape_mnk)); - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<16>, // B - AutoVectorizingCopyWithAssumedAlignment<16>, // C - thread_block_size, - tiled_mma_t, - 16, + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 128; - constexpr uint32_t n = 128; - constexpr uint32_t k = 64; constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_128{}, _128{}, _64{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; - using smem_a_atom_layout_t = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride<_64, _1>>{})); - using smem_b_atom_layout_t = decltype( - composition(Swizzle<3,3,3>{}, - Layout, - Stride< _1,_64>>{})); - using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + auto smem_a_atom_layout = composition(Swizzle<3,3,3>{}, Layout, Stride<_64, _1>>{}); + auto smem_b_atom_layout = composition(Swizzle<3,3,3>{}, Layout, Stride< _1,_64>>{}); + auto smem_c_atom_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); - using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); - using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk), GenRowMajor{}); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenColMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk), GenRowMajor{}); - using smem_a_atom_layout_t = smem_a_atom_layout_t; - using smem_a_layout_t = decltype(tile_to_shape( - smem_a_atom_layout_t{}, - make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) - ); + auto smem_a_layout = tile_to_shape( + smem_a_atom_layout, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); - // Transposed - using smem_b_atom_layout_t = smem_b_atom_layout_t; - using smem_b_layout_t = decltype(tile_to_shape( - smem_b_atom_layout_t{}, - make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) - ); + auto smem_b_layout = tile_to_shape( + smem_b_atom_layout, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); - using smem_c_atom_layout_t = smem_c_atom_layout_t; - using smem_c_layout_t = decltype(tile_to_shape( - smem_c_atom_layout_t{}, - make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) - ); + auto smem_c_layout = tile_to_shape( + smem_c_atom_layout, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); - test_cooperative_gemm, // A - AutoVectorizingCopyWithAssumedAlignment<128>, // B - AutoVectorizingCopyWithAssumedAlignment<128>, // C - thread_block_size, - tiled_mma_t, - 128, + test_cooperative_gemm(); + value_type> + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma); } TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) { + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 64; using TA = float; using TB = float; using TC = double; - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; auto aload = cute::negate {}; auto bload = cute::negate {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) { - using value_type = cutlass::half_t; - - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; constexpr uint32_t thread_block_size = 128; + using value_type = cutlass::half_t; - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom, Layout> - >; + >{}; auto aload = cute::negate {}; auto bload = cute::negate {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } template @@ -398,26 +338,25 @@ struct convert_to { }; TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) { + + constexpr uint32_t thread_block_size = 128; + constexpr uint32_t max_vec_bits = 64; + using TA = float; using TB = float; using TC = double; - constexpr uint32_t m = 32; - constexpr uint32_t n = 32; - constexpr uint32_t k = 32; - - constexpr uint32_t thread_block_size = 128; - - using tiled_mma_t = TiledMMA< + auto shape_mnk = make_shape(_32{}, _32{}, _32{}); + auto tiled_mma = TiledMMA< MMA_Atom>, Layout> - >; + >{}; auto aload = increment_by_x{1.111f}; auto bload = convert_to {}; auto cload = cute::negate {}; auto cstore = cute::negate {}; - test_cooperative_gemm_col_major_layout( - aload, bload, cload, cstore); + test_cooperative_gemm_col_major_layout( + shape_mnk, tiled_mma, aload, bload, cload, cstore); } diff --git a/test/unit/cute/volta/vectorization_auto.cu b/test/unit/cute/volta/vectorization_auto.cu index b378f8b3..585abf0e 100644 --- a/test/unit/cute/volta/vectorization_auto.cu +++ b/test/unit/cute/volta/vectorization_auto.cu @@ -67,7 +67,6 @@ kernel(GmemTensor gC, RmemTiler tiler, CopyPolicy policy) // NOTE: only 1 thread, this thread produce a block of 8x8 output. The fringe will not be touched. //copy(rC, tCgC); // Enable auto-vectorization if static - //copy_vec(rC, tCgC); // Disable auto-vectorization always copy(policy, rC, tCgC); // Use a policy to establish vectorization assumptions } diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 488c6bfa..87b6e53d 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -26,52 +26,30 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -add_custom_target( - cutlass_test_unit_gemm_device - DEPENDS - cutlass_test_unit_gemm_device_simt - cutlass_test_unit_gemm_device_tensorop_sm70 - cutlass_test_unit_gemm_device_tensorop_sm75 - cutlass_test_unit_gemm_device_tensorop_f16_sm80 - cutlass_test_unit_gemm_device_tensorop_f32_sm80 - cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80 - cutlass_test_unit_gemm_device_tensorop_f64 - cutlass_test_unit_gemm_device_tensorop_s32_sm80 - cutlass_test_unit_gemm_device_wmma - cutlass_test_unit_gemm_device_tensorop_planar_complex - cutlass_test_unit_gemm_device_sparse_tensorop_sm80 - cutlass_test_unit_gemv_device - cutlass_test_unit_gemm_device_tensorop_sm90 - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 - cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 -) +add_custom_target(cutlass_test_unit_gemm_device) +add_custom_target(test_unit_gemm_device) -add_custom_target( - test_unit_gemm_device - DEPENDS - test_unit_gemm_device_simt - test_unit_gemm_device_tensorop_sm70 - test_unit_gemm_device_tensorop_sm75 - test_unit_gemm_device_tensorop_f16_sm80 - test_unit_gemm_device_tensorop_f32_sm80 - test_unit_gemm_device_tensorop_f32_tf32_sm80 - test_unit_gemm_device_tensorop_f64 - test_unit_gemm_device_tensorop_s32_sm80 - test_unit_gemm_device_wmma - test_unit_gemm_device_tensorop_planar_complex - test_unit_gemm_device_sparse_tensorop_sm80 - test_unit_gemv_device - test_unit_gemm_device_tensorop_sm90 -) +################################################################################ -add_custom_target( - cutlass_test_unit_gemm_device_sm90 - DEPENDS - cutlass_test_unit_gemm_device_tensorop_sm90 - cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 -) +function(cutlass_test_unit_gemm_device_add_deps NAME) + string(REGEX REPLACE "^cutlass_" "" TEST_NAME "${NAME}") + add_dependencies(cutlass_test_unit_gemm_device ${NAME}) + add_dependencies(test_unit_gemm_device ${TEST_NAME}) +endfunction() -cutlass_test_unit_add_executable( +function(cutlass_test_unit_gemm_device_add_executable NAME) + cutlass_test_unit_add_executable(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME) + cutlass_test_unit_gemm_device_add_deps(${NAME}) +endfunction() + +function(cutlass_test_unit_gemm_device_add_executable_split_file NAME) + cutlass_test_unit_add_executable_split_file(${NAME} ${ARGN} DO_NOT_LOWERCASE_TEST_NAME) + cutlass_test_unit_gemm_device_add_deps(${NAME}) +endfunction() + +################################################################################ + +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt BATCH_SOURCES ON @@ -126,7 +104,9 @@ cutlass_test_unit_add_executable( gemm_splitk_simt_sm50.cu ) -cutlass_test_unit_add_executable( +list(APPEND CUTLASS_TEST_UNIT_GEMM_DEVICE_LIST cutlass_test_unit_gemm_device_simt) + +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt_3x BATCH_SOURCES ON @@ -139,8 +119,7 @@ cutlass_test_unit_add_executable( sm61_gemm_s8_s8_s32_simt.cu ) - -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm70 BATCH_SOURCES ON @@ -159,7 +138,7 @@ cutlass_test_unit_add_executable( gemm_splitk_tensor_op_sm70.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm75 BATCH_SOURCES ON @@ -204,7 +183,7 @@ cutlass_test_unit_add_executable( ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f16_sm80 BATCH_SOURCES ON @@ -214,7 +193,7 @@ cutlass_test_unit_add_executable( gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_sm80 BATCH_SOURCES ON @@ -236,7 +215,7 @@ cutlass_test_unit_add_executable( gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_sm80_3x sm80_gemm_s8_s8_s32_tensor_op.cu @@ -245,7 +224,7 @@ cutlass_test_unit_add_executable( ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 BATCH_SOURCES ON @@ -286,13 +265,14 @@ cutlass_test_unit_add_executable( gemm_universal_s8t_s4n_s8t_mixed_input_tensor_op_s32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90 BATCH_SOURCES ON BATCH_SIZE 4 sm90_gemm_f16_f16_f16_tensor_op.cu + sm90_gett_f16_f16_f16_tensor_op.cu sm90_gemm_bf16_bf16_bf16_tensor_op_f32.cu sm90_gemm_s8_s8_s8_tensor_op_s32.cu sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -302,7 +282,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f8_f8_f8_tensor_op_fp32.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_stream_k sm90_gemm_stream_k_scheduler.cu @@ -311,7 +291,7 @@ cutlass_test_unit_add_executable( ) # Alignment tests -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_alignx_sm90 BATCH_SOURCES ON @@ -336,14 +316,14 @@ cutlass_test_unit_add_executable( ) # Ptr Array test -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_ptr_array sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array_pingpong.cu ) # Group Gemm test -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90_group_gemm sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm.cu sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu @@ -351,31 +331,25 @@ cutlass_test_unit_add_executable( # Sparse tests # Sparse kernels trigger an ICE in gcc 7.5 -if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) -cutlass_test_unit_add_executable( - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 +if (NOT (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) + + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_sparse_gemm_device_tensorop_sm90 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu + sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu + sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu + sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu + ) - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu - sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu - sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu - sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu -) -else() -cutlass_test_unit_add_executable( - cutlass_test_unit_sparse_gemm_device_tensorop_sm90 - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 -) endif() # Fused epilogue tests -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_sm90 BATCH_SOURCES ON @@ -400,7 +374,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 BATCH_SOURCES ON @@ -412,7 +386,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90 BATCH_SOURCES ON @@ -423,7 +397,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f32_tf32_sm80 BATCH_SOURCES ON @@ -443,7 +417,7 @@ cutlass_test_unit_add_executable( sm80_gemm_f16_f16_f32_tensor_op_f32.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_f64 BATCH_SOURCES ON @@ -471,7 +445,7 @@ cutlass_test_unit_add_executable( gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_s32_sm80 BATCH_SOURCES ON @@ -493,7 +467,7 @@ cutlass_test_unit_add_executable( gemm_s4n_s4t_s4n_tensor_op_s32_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_wmma BATCH_SOURCES ON @@ -551,7 +525,7 @@ cutlass_test_unit_add_executable( gemm_f16t_f16n_f32t_singlestage_wmma_tensor_op_f32_sm70.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_planar_complex BATCH_SOURCES ON @@ -562,7 +536,7 @@ cutlass_test_unit_add_executable( gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_tensorop_sm89 BATCH_SOURCES ON @@ -574,7 +548,7 @@ cutlass_test_unit_add_executable( # gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped BATCH_SOURCES ON @@ -583,7 +557,7 @@ cutlass_test_unit_add_executable( gemm_grouped_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_scheduler BATCH_SOURCES ON @@ -592,7 +566,7 @@ cutlass_test_unit_add_executable( gemm_grouped_scheduler_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_rank_2k_scheduler BATCH_SOURCES ON @@ -601,7 +575,7 @@ cutlass_test_unit_add_executable( rank_2k_grouped_scheduler_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_sparse_tensorop_sm80 BATCH_SOURCES ON @@ -622,7 +596,7 @@ cutlass_test_unit_add_executable( gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemv_device BATCH_SOURCES ON @@ -631,42 +605,22 @@ cutlass_test_unit_add_executable( gemv.cu ) -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUTLASS_NVCC_DEVICE_COMPILE) -add_dependencies( - cutlass_test_unit_gemm_device - cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop + + gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu + gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu + + gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu ) -add_dependencies( - test_unit_gemm_device - test_unit_gemm_device_gemm_with_fused_epilogue_tensorop - ) - -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop - - gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu - gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu - - gemm_with_reduction_f16t_f16n_f16n_tensorop_f32_sm80.cu -) - endif() -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUTLASS_NVCC_DEVICE_COMPILE) -add_dependencies( - cutlass_test_unit_gemm_device - cutlass_test_unit_gemm_device_blas3 - ) - -add_dependencies( - test_unit_gemm_device - test_unit_gemm_device_blas3 - ) - -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_blas3 BATCH_SOURCES ON @@ -833,7 +787,7 @@ cutlass_test_unit_add_executable( hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_grouped_blas3 BATCH_SOURCES ON @@ -858,13 +812,12 @@ cutlass_test_unit_add_executable( endif() -if (NOT CUDA_COMPILER MATCHES "[Cc]lang") +if (CUTLASS_NVCC_DEVICE_COMPILE) -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_broadcast - - gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu -) + cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_broadcast + gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu + ) endif() diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 3d5a586d..3a6cf0b2 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -39,6 +39,7 @@ #include #include #include +#include // std::lcm #include "../../common/cutlass_unit_test.h" #include "cutlass/util/host_tensor.h" @@ -55,6 +56,7 @@ #include "cutlass/complex.h" #include "cutlass/transform/device/transform_universal_adapter.hpp" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" #include "testbed_utils.h" @@ -151,6 +153,12 @@ struct ElementScalarType +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + // The maximum swizzle size to use // // This class, like Splits above makes it harder to confuse @@ -1140,7 +1148,6 @@ struct HostCollectiveEpilogue { static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && (cute::is_same_v || cute::is_same_v); - using Arguments = typename Gemm::GemmKernel::EpilogueArguments; /// Initialization @@ -1454,6 +1461,22 @@ struct HostCollectiveEpilogue { bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif std::cout<<"D is incorrect"<) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } if constexpr (IsPerRowScaleEnabled) { int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; @@ -1620,6 +1646,7 @@ struct HostCollectiveEpilogue { // example of how to set kernel activation arguments // see ActivationFunctor::Arguments in activation.h for definition // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { fusion_args.activation.scale = ElementCompute(1); } @@ -1713,6 +1740,7 @@ struct HostCollectiveEpilogue { decltype(Vbeta), ActivationFunctor, cutlass::plus + , false /*PerColumnBias_*/ > epilogue_params{}; epilogue_params.C = C; @@ -1779,6 +1807,7 @@ struct TestbedImpl { using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, HostCollectiveDefaultEpilogue, HostCollectiveEpilogue>; @@ -2004,7 +2033,7 @@ struct TestbedImpl { return false; } } - catch (std::exception const& e) { + catch ([[maybe_unused]] std::exception const& e) { CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); throw; } diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index b7d1c579..479102b3 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -346,7 +346,7 @@ struct HostCollectiveMainloop { stride_b_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); for(int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -380,7 +380,7 @@ struct HostCollectiveMainloop { Arguments to_args(ProblemShapeType problem_shapes) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); std::vector ptr_A_host(L); std::vector ptr_B_host(L); @@ -587,7 +587,7 @@ struct HostCollectiveDefaultEpilogue { stride_d_host.clear(); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -649,7 +649,7 @@ struct HostCollectiveDefaultEpilogue { ElementScalar beta, int batch) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); tensors_D[batch].sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); @@ -678,7 +678,7 @@ struct HostCollectiveDefaultEpilogue { Arguments to_args(ProblemShapeType problem_shapes) { auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = cutlass::platform::max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); @@ -724,8 +724,8 @@ struct HostCollectiveDefaultEpilogue { // // Allocate the GEMM workspace // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), @@ -905,9 +905,8 @@ struct HostCollectiveEpilogue { references_D.clear(); stride_c_host.clear(); stride_d_host.clear(); - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); for (int32_t i = 0; i < L; ++i) { auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); @@ -1118,7 +1117,6 @@ struct HostCollectiveEpilogue { passed &= tmp; } } - return passed; } @@ -1189,7 +1187,7 @@ struct HostCollectiveEpilogue { Arguments to_args(ProblemShapeType problem_shapes) { auto coord_0 = cutlass::make_Coord(0); auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); std::vector ptr_C_host(L); std::vector ptr_D_host(L); @@ -1220,19 +1218,22 @@ struct HostCollectiveEpilogue { device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); } + auto device_tensors_C_ptr = cute::is_void_v ? nullptr : + reinterpret_cast(device_tensors_C.get()); + Arguments arguments; if constexpr (IsGroupGemm) { arguments = { {}, - device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() }; } else { arguments = { {}, - device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] }; } @@ -1252,7 +1253,9 @@ struct HostCollectiveEpilogue { fusion_args.beta = beta.at(coord_0); fusion_args.alpha_ptr = alpha.device_data(); - fusion_args.beta_ptr = beta.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); if constexpr (IsScaleFactorEnabled) { fusion_args.scale_a = scale_A.at(coord_0); @@ -1316,7 +1319,8 @@ struct HostCollectiveEpilogue { // // Allocate the GEMM workspace // - auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; auto coord_0 = cutlass::make_Coord(0); auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); @@ -1338,7 +1342,6 @@ struct HostCollectiveEpilogue { cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); - cutlass::reference::host::GettEpilogueParams< ElementScalar, ElementScalar, @@ -1518,7 +1521,7 @@ struct TestbedImpl { { using namespace cute; auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); - L = max(problem_shapes.groups(), L); + L = std::max(problem_shapes.groups(), L); bool passed = true; for (int32_t i = 0; i < L; ++i) { @@ -1760,7 +1763,7 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative cutlass::DeviceAllocation problem_sizes_device; for (int i = 0; i < batch; ++i) { - problem_sizes_host.push_back({m, n, k}); + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); } problem_sizes_device.reset(problem_sizes_host.size()); diff --git a/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu b/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu new file mode 100644 index 00000000..4d03fc93 --- /dev/null +++ b/test/unit/gemm/device/sm90_gett_f16_f16_f16_tensor_op.cu @@ -0,0 +1,184 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/reference/device/gett.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gett_f16t_f16n_f16n_tensor_op_gmma_f16, 8x8x8x8x8x8) { + + using BatModeStrides = int; + + using RowModeStridesA = cute::Stride; + using RedModeStrides = cute::Stride; + + using ColModeStridesB = cute::Stride; + + using RowModeStridesC = cute::Stride; + using ColModeStridesC = cute::Stride; + + using StrideA = cute::Stride; + using StrideB = cute::Stride; + using StrideC = cute::Stride; + using StrideD = StrideC; + + using TileShape = Shape, Shape<_8, _8>, Shape<_8, _8>>; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, StrideA, 8, + cutlass::half_t, StrideB, 8, + cutlass::half_t, + TileShape, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + cutlass::half_t, cutlass::half_t, + cutlass::half_t, StrideC, 8, + cutlass::half_t, StrideC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + Shape, + Shape, + int>, + CollectiveOp, + CollectiveEpilogue + >; + + using Gett = cutlass::gemm::device::GemmUniversalAdapter; + + auto problem_shape = make_shape( + make_shape(32,8), + make_shape(32,4), + make_shape(32,2), + 1 + ); + + auto [M, N, K, L] = problem_shape; + + StrideA dA = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(M) * size(K)); + StrideB dB = make_stride(make_stride(64, 2048), make_stride(_1{}, 32), size(N) * size(K)); + StrideC dC = make_stride(make_stride(_1{}, 32), make_stride(256, 8192), size(M) * size(N)); + StrideD dD = dC; + + cutlass::half_t alpha = cutlass::half_t(1.0f); + cutlass::half_t beta = cutlass::half_t(1.0f); + + thrust::host_vector A_h(size(M) * size(K) * size(L)); + thrust::host_vector B_h(size(N) * size(K) * size(L)); + thrust::host_vector C_h(size(M) * size(N) * size(L)); + thrust::host_vector D_h(size(M) * size(N) * size(L)); + thrust::host_vector D_h_ref(size(M) * size(N) * size(L)); + + for (auto& a : A_h) a = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& b : B_h) b = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& c : C_h) c = cutlass::half_t(static_cast(4 * (rand() / double(RAND_MAX) - 1))); + for (auto& d : D_h) d = cutlass::half_t(-1); + for (auto& d : D_h_ref) d = cutlass::half_t(-1); + + thrust::device_vector A = A_h; + thrust::device_vector B = B_h; + thrust::device_vector C = C_h; + thrust::device_vector D = D_h; + thrust::device_vector D_ref = D_h_ref; + + typename Gett::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape, + {A.data().get(), dA, B.data().get(), dB}, + { {alpha, beta}, C.data().get(), dC, D.data().get(), dD} + }; + + Gett gett; + auto status = gett(args); + EXPECT_TRUE(status == cutlass::Status::kSuccess); + auto cuda_err = cudaDeviceSynchronize(); + + EXPECT_TRUE(cuda_err == cudaSuccess); + + cutlass::reference::device::gett( + problem_shape, + A.data().get(), dA, + B.data().get(), dB, + cutlass::half_t(0.0f), + C.data().get(), dC, + D_ref.data().get(), dD, + alpha, beta); + + cuda_err = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_err == cudaSuccess); + + bool passed = cutlass::reference::device::BlockCompareEqual( + D.data().get(), D_ref.data().get(), D_ref.size()); + EXPECT_TRUE(passed); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/transform/CMakeLists.txt b/test/unit/transform/CMakeLists.txt index 4912eca2..0ab0b93f 100644 --- a/test/unit/transform/CMakeLists.txt +++ b/test/unit/transform/CMakeLists.txt @@ -33,7 +33,7 @@ add_custom_target( cutlass_test_unit_transform DEPENDS cutlass_test_unit_transform_threadblock - cutlass_test_unit_transform_filter_format + cutlass_test_unit_transform_kernel ) add_custom_target( diff --git a/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp index 3a42d74f..8ec0c4ac 100644 --- a/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp +++ b/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -45,6 +45,7 @@ #include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor #include "cutlass/arch/arch.h" // cutlass::arch::Sm90 #include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" #include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t #include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up #include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo @@ -219,9 +220,7 @@ public: // * EltA using ElementA = ElementA_; using ElementAUint = cute::uint_bit_t>; - static constexpr bool IsRuntimeDataTypeA = cute::is_same_v || - cute::is_same_v || - cute::is_same_v; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); using ArrayElementA = cute::conditional_t>, ElementA>; diff --git a/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp index af50348e..03e4fa75 100644 --- a/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp +++ b/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -60,6 +60,7 @@ #include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride #include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals #include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" #include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor #include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE diff --git a/test/unit/transform/kernel/CMakeLists.txt b/test/unit/transform/kernel/CMakeLists.txt index d337b31e..92d4a47b 100644 --- a/test/unit/transform/kernel/CMakeLists.txt +++ b/test/unit/transform/kernel/CMakeLists.txt @@ -27,6 +27,6 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. cutlass_test_unit_add_executable( - cutlass_test_unit_transform_filter_format + cutlass_test_unit_transform_kernel filter_format_transformer.cu ) diff --git a/test/unit/util/rms_norm.cu b/test/unit/util/rms_norm.cu index c366f7e0..41114067 100644 --- a/test/unit/util/rms_norm.cu +++ b/test/unit/util/rms_norm.cu @@ -104,7 +104,7 @@ void run_test(int M, int N) { for (int n = 0; n < N; ++n) { auto diff = abs(static_cast(output_ref.at({m, n}) - output.at({m, n}))); mean_abs_diff += diff; - max_abs_diff = max(max_abs_diff, diff); + max_abs_diff = cutlass::platform::max(max_abs_diff, diff); } } diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index f55b5131..d87d0895 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -72,6 +72,8 @@ private: /// Pointer to the most recently executed operation Operation const *last_operation_; + int device_idx_; + public: /// Constructor diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 56e6e455..19812d4b 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -103,7 +103,6 @@ public: void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; - // Originally designed for metadata, but should be useful for FP8/6/4 too. virtual Status initialize_with_profiler_workspace( void const *configuration, void *host_workspace, @@ -118,7 +117,8 @@ public: void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const = 0; + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const = 0; }; @@ -272,6 +272,8 @@ struct GemmUniversalConfiguration { int64_t ldb{0}; int64_t ldc{0}; int64_t ldd{0}; + + int device_count{1}; }; struct GemmUniversalArguments { @@ -303,6 +305,8 @@ struct GemmUniversalArguments { int sm_count{0}; library::RasterOrder raster_order{}; int swizzle_size{1}; + + int device_index{0}; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/conv2d_operation.h b/tools/library/src/conv2d_operation.h index cf29c889..027b2615 100644 --- a/tools/library/src/conv2d_operation.h +++ b/tools/library/src/conv2d_operation.h @@ -326,7 +326,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -578,7 +583,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/conv3d_operation.h b/tools/library/src/conv3d_operation.h index 758866b8..6cb1796b 100644 --- a/tools/library/src/conv3d_operation.h +++ b/tools/library/src/conv3d_operation.h @@ -317,7 +317,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/conv_operation_3x.hpp b/tools/library/src/conv_operation_3x.hpp index 00931822..d6f79e91 100644 --- a/tools/library/src/conv_operation_3x.hpp +++ b/tools/library/src/conv_operation_3x.hpp @@ -236,12 +236,14 @@ public: typename Operator::Arguments out_args{}; status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); return status; } auto* in_args_ptr = reinterpret_cast(arguments); status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); return status; } @@ -332,7 +334,8 @@ public: void const* arguments, void* host_workspace, void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const override + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { auto status = Status::kInvalid; @@ -358,7 +361,7 @@ public: } auto* op = reinterpret_cast(host_workspace); - return op->run(out_args, device_workspace, stream); + return op->run(out_args, device_workspace, stream, nullptr, launch_with_pdl); } private: @@ -482,6 +485,11 @@ private: typename Operator::Arguments& out_args, Conv2dConfiguration const& config) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif using detail::vector_to_array_strides; constexpr int num_spatial_dims = Operator::NumSpatialDimensions; @@ -595,6 +603,7 @@ private: const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); // cutlass::library::Conv2dConfiguration has no member stride_d. // The code below imitates the testbed, @@ -605,12 +614,57 @@ private: CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); return Status::kInvalid; } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + problem_shape_type problem_shape( /* mode = */ mode, - /* shape_act = */ {N, H, W, C}, - /* stride_act = */ stride_A, - /* shape_flt = */ {K, R, S, C}, - /* stride_flt = */ stride_B, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, /* lower_padding = */ {pad_h, pad_w}, /* upper_padding = */ {pad_h, pad_w}, /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, @@ -620,9 +674,11 @@ private: // ConvProblemShape's constructor sets its shape_C member. #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " problem_shape:\n" - << " shape_C: " << problem_shape.shape_C << "\n"; - std::cerr << " stride_C: " << problem_shape.stride_C << "\n"; + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); #endif // Initialization of C's and D's strides follows the CUTLASS 3 // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). @@ -670,6 +726,11 @@ private: typename Operator::Arguments& out_args, Conv3dConfiguration const& config) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif using detail::coord_to_array_strides; constexpr int num_spatial_dims = Operator::NumSpatialDimensions; @@ -762,6 +823,10 @@ private: print_stride(input_stride_b, "input_stride_b"); print_stride(input_stride_c, "input_stride_c"); #endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; using problem_shape_type = @@ -771,18 +836,68 @@ private: const TensorStride stride_A = coord_to_array_strides(input_stride_a); const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); const int num_groups = config.problem_size.groups; if (num_groups != 1) { CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); return Status::kInvalid; } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + problem_shape_type problem_shape( /* mode = */ mode, - /* shape_act = */ {N, D, H, W, C}, - /* stride_act = */ stride_A, - /* shape_flt = */ {K, T, R, S, C}, - /* stride_flt = */ stride_B, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, /* lower_padding = */ {pad_d, pad_h, pad_w}, /* upper_padding = */ {pad_d, pad_h, pad_w}, /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, @@ -792,15 +907,15 @@ private: // ConvProblemShape's constructor sets its shape_C member. #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " problem_shape:\n" - << " shape_C: " << problem_shape.shape_C << "\n"; - std::cerr << " stride_C: " << problem_shape.stride_C << "\n"; + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); #endif - + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). { -#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " Compute stride_C and stride_D\n"; -#endif using StrideC = typename Operator::ConvKernel::StrideC; using StrideD = typename Operator::ConvKernel::StrideD; auto stride_C = StrideC{}; @@ -845,9 +960,8 @@ private: ConvArguments const& in_args) const { #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << "ConvOperation3x::update_operator_arguments_from_arguments\n"; + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); #endif - auto status = UpdateFusionArgs::update_( out_args.epilogue.thread, in_args); if (status != Status::kSuccess) { diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 2c2c5c9c..5c6f9ca8 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -296,7 +296,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -500,7 +505,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -721,7 +731,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -930,7 +945,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -1133,7 +1153,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; @@ -1337,7 +1362,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index f4918b7d..7c87b45e 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" #include "cutlass/library/library.h" #include "library_internal.h" #include "cutlass/gemm/dispatch_policy.hpp" @@ -331,7 +332,8 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { OperatorArguments args; Status status = update_arguments_(args, static_cast(arguments_ptr)); @@ -341,7 +343,7 @@ public: Operator *op = static_cast(host_workspace); // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(args, device_workspace, stream); + status = op->run(args, device_workspace, stream, nullptr, launch_with_pdl); return status; } }; diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index cfb176be..e6f00f72 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -57,14 +57,12 @@ Handle::Handle( scalar_pointer_mode_(ScalarPointerMode::kHost), last_operation_(nullptr) { - int device_idx = -1; - - cudaError_t error = cudaGetDevice(&device_idx); + cudaError_t error = cudaGetDevice(&device_idx_); if (error != cudaSuccess) { throw std::runtime_error("cudaGetDevice() failed"); } - error = cudaGetDeviceProperties(&device_, device_idx); + error = cudaGetDeviceProperties(&device_, device_idx_); if (error != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } @@ -78,8 +76,14 @@ Handle::Handle( Handle::~Handle() { if (workspace_) { - if (workspace_) { - cudaFree(workspace_); + int device_before; + cudaGetDevice(&device_before); + if (device_before != device_idx_) { + cudaSetDevice(device_idx_); + } + cudaFree(workspace_); + if (device_before != device_idx_) { + cudaSetDevice(device_before); } workspace_ = nullptr; @@ -89,6 +93,10 @@ Handle::~Handle() { /// Move constructor Handle::Handle(Handle && handle) { + cudaError_t error = cudaGetDevice(&device_idx_); + if (error != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } device_ = handle.device_; workspace_size_ = handle.workspace_size_; workspace_ = handle.workspace_; @@ -112,6 +120,8 @@ Handle & Handle::operator=(Handle && handle) { handle.workspace_ = nullptr; handle.workspace_size_ = 0; + device_idx_ = handle.device_idx_; + return *this; } @@ -151,6 +161,12 @@ void *Handle::get_workspace() const { /// Sets the size of device workspace, invalidating previous calls to get_device_workspace() void Handle::set_workspace_size(size_t bytes) { + int device_before; + cudaGetDevice(&device_before); + if (device_before != device_idx_) { + cudaSetDevice(device_idx_); + } + if (bytes != workspace_size_) { if (workspace_) { @@ -177,6 +193,9 @@ void Handle::set_workspace_size(size_t bytes) { throw std::runtime_error("Failed to clear workspace"); } } + if (device_before != device_idx_) { + cudaSetDevice(device_before); + } } /// Gets the scalar pointer mode diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 2b57dbc3..be311c62 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -72,6 +72,10 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kB1; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS4; }; @@ -92,6 +96,10 @@ template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kS64; }; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + template <> struct NumericTypeMap { static NumericTypeID const kId = NumericTypeID::kU4; }; diff --git a/tools/library/src/rank_2k_operation.h b/tools/library/src/rank_2k_operation.h index b353a347..5a611104 100644 --- a/tools/library/src/rank_2k_operation.h +++ b/tools/library/src/rank_2k_operation.h @@ -314,7 +314,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/rank_k_operation.h b/tools/library/src/rank_k_operation.h index e5e5ec78..e6afb1da 100644 --- a/tools/library/src/rank_k_operation.h +++ b/tools/library/src/rank_k_operation.h @@ -310,7 +310,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/reduction/reduction_operation.h b/tools/library/src/reduction/reduction_operation.h index ceb46278..3bcabf09 100644 --- a/tools/library/src/reduction/reduction_operation.h +++ b/tools/library/src/reduction/reduction_operation.h @@ -231,7 +231,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index ab924b5f..2bafc4af 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -432,7 +432,12 @@ public: void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } ConvArguments const &args = *static_cast(arguments); diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h index fd58d4f0..940ff521 100644 --- a/tools/library/src/reference/gemm_reference_operation.h +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -192,7 +192,12 @@ public: void const *arguments, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } GemmUniversalConfiguration const &config = *static_cast(host_workspace); GemmUniversalArguments const &args = *static_cast(arguments); diff --git a/tools/library/src/reference/gemm_s8_s8_s32.cu b/tools/library/src/reference/gemm_s8_s8_s32.cu index d88e986f..8c661b98 100644 --- a/tools/library/src/reference/gemm_s8_s8_s32.cu +++ b/tools/library/src/reference/gemm_s8_s8_s32.cu @@ -77,7 +77,7 @@ void initialize_gemm_reference_operations_s8_s8_s32(Manifest &manifest) { int8_t, // ElementA int8_t, // ElementB int32_t, // ElementC - int32_t, // ElementScalar / ElementCompute + float, // ElementScalar / ElementCompute int32_t, // ElementAccumulator int32_t // ElementD >(manifest); diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp index fec987f5..8bfc41d7 100644 --- a/tools/library/src/sparse_gemm_operation_3x.hpp +++ b/tools/library/src/sparse_gemm_operation_3x.hpp @@ -35,6 +35,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" #include "cutlass/library/library.h" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor #include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter @@ -56,7 +57,7 @@ namespace cutlass::library { /////////////////////////////////////////////////////////////////////////////////////////////////// -// Limitation & Assumptions: +// Limitation & Assumptions: // 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, // and lda is m if the tensor is m-major. // 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. @@ -169,7 +170,6 @@ protected: return status; } - // TODO: type erase Arguments structure in 3.0 GEMM operator_args.problem_shape = cute::make_shape( arguments->problem_size.m(), arguments->problem_size.n(), @@ -302,13 +302,15 @@ public: } Status initialize_with_profiler_workspace( - void const *configuration, - void *host_workspace, - void *device_workspace, + void const *configuration, + void *host_workspace, + void *device_workspace, uint8_t **profiler_workspaces, int problem_count_from_profiler, cudaStream_t stream = nullptr) { + iter_idx.resize(static_cast(configuration)->device_count, 0); + // Set problem_count. problem_count = problem_count_from_profiler; @@ -319,13 +321,10 @@ public: // * Construct Op Operator *op = new (host_op_workspace_ptr) Operator; - // * Device Full Ptr - device_full_ptr = reinterpret_cast(device_workspace); - // * Device Ptr (1st iteration) // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iter1 = device_full_ptr; + auto* device_ptr_iter1 = static_cast(device_workspace); auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; @@ -335,15 +334,15 @@ public: auto* device_a_raw_ptr = profiler_workspaces[0]; // * Random fill 50% of TensorA w/ zero following the structured sparse requirement - cudaMemcpy(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost); + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); - cudaMemcpy(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); CUDA_CHECK(cudaGetLastError()); // * Compress DTensorA and get DTensorAC & DTensorE cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); typename Compressor::Arguments arguments{ {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, @@ -372,23 +371,23 @@ public: return status; } - CUDA_CHECK(cudaStreamSynchronize(stream)); - // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE for (int iter_i = 1; iter_i < problem_count; iter_i++) { // * Device AC E Ptr per iteration // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | // iteri : op_workspace | tensor_ac | tensor_e - auto* device_ptr_iteri = device_full_ptr + device_per_iter_workspace_size * iter_i; + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; auto* device_op_workspace_ptr = device_ptr_iteri; auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - cudaMemcpy(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice); - cudaMemcpy(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice); + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); } + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaGetLastError()); return Status::kSuccess; @@ -398,17 +397,21 @@ public: Status run( void const *arguments_ptr, void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const override { + void *device_workspace, + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const override { OperatorArguments operator_args; - auto* device_ptr_iteri = device_full_ptr + device_per_iter_workspace_size * iter_idx; + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; auto* device_op_workspace_ptr = device_ptr_iteri; auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; - iter_idx = (iter_idx + 1) % problem_count; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); @@ -418,7 +421,7 @@ public: Operator *op = static_cast(host_workspace); // We need to call initialize() since we have to rebuild TMA desc for every new set of args - status = op->run(operator_args, device_op_workspace_ptr, stream); + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, launch_with_pdl); return status; } @@ -426,9 +429,7 @@ private: // Variables that must change in the const functions. mutable CompressorUtility compressor_utility; mutable int problem_count = 1; - mutable int iter_idx = 0; - - uint8_t* device_full_ptr = nullptr; + mutable std::vector iter_idx; mutable uint64_t tensor_ac_size = 0; mutable uint64_t tensor_e_size = 0; diff --git a/tools/library/src/symm_operation.h b/tools/library/src/symm_operation.h index 548356b4..aeb06caf 100644 --- a/tools/library/src/symm_operation.h +++ b/tools/library/src/symm_operation.h @@ -312,7 +312,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/library/src/trmm_operation.h b/tools/library/src/trmm_operation.h index 80e4ad14..88c4f7ab 100644 --- a/tools/library/src/trmm_operation.h +++ b/tools/library/src/trmm_operation.h @@ -304,7 +304,12 @@ public: void const *arguments_ptr, void *host_workspace, void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + cudaStream_t stream = nullptr, + bool launch_with_pdl = false) const { + + if (launch_with_pdl) { + return Status::kErrorNotSupported; + } OperatorArguments args; diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 1a7fb128..d71caf41 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -100,7 +100,7 @@ install( if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 AND (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED))) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) else() - set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) endif() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true) diff --git a/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h index 24975013..32d79211 100644 --- a/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -430,7 +430,7 @@ public: protected: /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h index 3cbf3106..2ce0a1c2 100644 --- a/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -384,7 +384,7 @@ protected: /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/tools/profiler/include/cutlass/profiler/cublas_helpers.h index 32875245..10642e5f 100644 --- a/tools/profiler/include/cutlass/profiler/cublas_helpers.h +++ b/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -304,7 +304,7 @@ struct cublasLtGemmExDispatcher { ); /// Executes GEMM using these arguments - cublasStatus_t operator()(cublasLtHandle_t handle); + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); ~cublasLtGemmExDispatcher(){ diff --git a/tools/profiler/include/cutlass/profiler/enumerated_types.h b/tools/profiler/include/cutlass/profiler/enumerated_types.h index 25a42296..3e6efa48 100644 --- a/tools/profiler/include/cutlass/profiler/enumerated_types.h +++ b/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -90,9 +90,9 @@ AlgorithmMode from_string(std::string const &str); /// Outcome of a performance test enum class Disposition { kPassed, - kFailed, + kFailed, // kernel itself reported an error kNotRun, - kIncorrect, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result kNotVerified, kInvalidProblem, kNotSupported, diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 8e1292f9..b103e3db 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -143,6 +143,8 @@ public: /// Buffer used for the cutlass reduction operations' host workspace std::vector reduction_host_workspace; + + cudaStream_t stream; }; protected: @@ -155,7 +157,7 @@ protected: GemmProblem problem_; /// Device memory allocations - GemmWorkspace gemm_workspace_; + std::vector gemm_workspace_; /// CUTLASS parallel reduction operation to follow this* gemm operation library::Operation const *reduction_op_; @@ -231,7 +233,8 @@ protected: DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem); + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); /// Verifies CUTLASS against host and device references bool verify_with_reference_( @@ -246,7 +249,7 @@ protected: /// Method to profile a CUTLASS Operation Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/gpu_timer.h b/tools/profiler/include/cutlass/profiler/gpu_timer.h index 304a362a..815b6af1 100644 --- a/tools/profiler/include/cutlass/profiler/gpu_timer.h +++ b/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -51,16 +51,21 @@ struct GpuTimer { // GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + ~GpuTimer(); - /// Records a start event in the stream - void start(cudaStream_t stream = nullptr); + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - /// Records a stop event in the stream - void stop(cudaStream_t stream = nullptr); + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); - /// Records a stop event in the stream and synchronizes on the stream - void stop_and_wait(cudaStream_t stream = nullptr); + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); /// Returns the duration in milliseconds double duration(int iterations = 1) const; diff --git a/tools/profiler/include/cutlass/profiler/operation_profiler.h b/tools/profiler/include/cutlass/profiler/operation_profiler.h index 3dfe3fcf..7e3005fe 100644 --- a/tools/profiler/include/cutlass/profiler/operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -232,7 +232,7 @@ protected: /// Method to profile an initialized CUTLASS operation virtual Status profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, diff --git a/tools/profiler/include/cutlass/profiler/performance_result.h b/tools/profiler/include/cutlass/profiler/performance_result.h index fb3393c4..4b9a3321 100644 --- a/tools/profiler/include/cutlass/profiler/performance_result.h +++ b/tools/profiler/include/cutlass/profiler/performance_result.h @@ -86,6 +86,9 @@ struct PerformanceResult { /// Average runtime in ms double runtime; + /// Average runtime in ms per device + std::vector runtime_vector; + // // Members // diff --git a/tools/profiler/src/conv2d_operation_profiler.cu b/tools/profiler/src/conv2d_operation_profiler.cu index f74ffbe7..9589c0ca 100644 --- a/tools/profiler/src/conv2d_operation_profiler.cu +++ b/tools/profiler/src/conv2d_operation_profiler.cu @@ -396,6 +396,29 @@ Status Conv2dOperationProfiler::initialize_configuration( problem_, operation_desc.conv_kind, operation_desc.A.layout, operation_desc.B.layout, operation_desc.C.layout); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + auto print_vector = [] (const auto& vec) { + printf("["); + for (size_t k = 0; k < vec.size(); ++k) { + cute::print(vec[k]); + if (k + 1 < vec.size()) { + printf(","); + } + } + printf("]"); + }; + + printf("\n conv_workspace_.configuration.stride_a: "); + print_vector(conv_workspace_.configuration.stride_a); + printf("\n conv_workspace_.configuration.stride_b: "); + print_vector(conv_workspace_.configuration.stride_b); + printf("\n conv_workspace_.configuration.stride_c: "); + print_vector(conv_workspace_.configuration.stride_c); + printf("\n"); + } +#endif + // initialize library::ConvArguments conv_workspace_.arguments.A = nullptr; conv_workspace_.arguments.B = nullptr; @@ -1237,7 +1260,7 @@ bool Conv2dOperationProfiler::profile( } results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &conv_workspace_.arguments, @@ -1251,7 +1274,7 @@ bool Conv2dOperationProfiler::profile( /// Method to profile a CUTLASS Operation Status Conv2dOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -1387,7 +1410,7 @@ Status Conv2dOperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/conv3d_operation_profiler.cu b/tools/profiler/src/conv3d_operation_profiler.cu index 8e8f5873..04d338c3 100644 --- a/tools/profiler/src/conv3d_operation_profiler.cu +++ b/tools/profiler/src/conv3d_operation_profiler.cu @@ -1099,7 +1099,7 @@ bool Conv3dOperationProfiler::profile( set_cutlass_operator_arguments_(); results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &conv_workspace_.arguments, @@ -1141,7 +1141,7 @@ void Conv3dOperationProfiler::set_cutlass_operator_arguments_(int problem_idx) { /// Method to profile a CUTLASS Operation Status Conv3dOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -1248,7 +1248,7 @@ Status Conv3dOperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/cublas_helpers.cu b/tools/profiler/src/cublas_helpers.cu index 7467c1db..412b0a24 100644 --- a/tools/profiler/src/cublas_helpers.cu +++ b/tools/profiler/src/cublas_helpers.cu @@ -656,7 +656,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle, return true; } -cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle) +cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle, cudaStream_t stream) { return cublasLtMatmul(handle, operationDesc, @@ -673,7 +673,7 @@ cublasStatus_t cublasLtGemmExDispatcher::operator()(cublasLtHandle_t handle) &heuristicResult_.algo, workspace, heuristicResult_.workspaceSize, - 0); //number of streams is set to 0 + stream); //number of streams is set to 0 } diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 4e57244e..a1866b55 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -290,9 +290,8 @@ DeviceAllocation::DeviceAllocation(): capacity_(0), pointer_(nullptr), layout_(library::LayoutTypeID::kUnknown), - batch_count_(1), - device_(-1) { - + batch_count_(1) { + cudaGetDevice(&device_); } DeviceAllocation::DeviceAllocation( @@ -329,13 +328,33 @@ DeviceAllocation::DeviceAllocation( DeviceAllocation::~DeviceAllocation() { if (pointer_) { + int current_device; + cudaGetDevice(¤t_device); + + if (current_device != device_) { + cudaSetDevice(device_); + } cudaFree(pointer_); + + if (current_device != device_) { + cudaSetDevice(current_device); + } } } DeviceAllocation &DeviceAllocation::reset() { if (pointer_) { + int current_device; + cudaGetDevice(¤t_device); + + if (current_device != device_) { + cudaSetDevice(device_); + } cudaFree(pointer_); + + if (current_device != device_) { + cudaSetDevice(current_device); + } } type_ = library::NumericTypeID::kInvalid; @@ -2438,25 +2457,11 @@ void DeviceAllocation::fill_host(double val = 0.0) { cudaError_t DeviceAllocation::malloc(void** ptr, size_t size) { cudaError_t result; - int set_device_back_to = -1; + int current_device; + cudaGetDevice(¤t_device); - /// When needed this sets the device to the allocation's device remembering - /// the current device so that it can be set back after the cudaMalloc is - /// performed. - if (device_ >= 0) { - int current_device; - result = cudaGetDevice(¤t_device); - if (result != cudaSuccess) { - return result; - } - - if (current_device != device_) { - set_device_back_to = current_device; - result = cudaSetDevice(device_); - if (result != cudaSuccess) { - return result; - } - } + if (current_device != device_) { + cudaSetDevice(device_); } // This performs the cudaMalloc @@ -2465,13 +2470,8 @@ cudaError_t DeviceAllocation::malloc(void** ptr, size_t size) { return result; } - /// When needed this sets the device back to what it was when the function was - /// called. - if (set_device_back_to != -1) { - result = cudaSetDevice(set_device_back_to); - if (result != cudaSuccess) { - return result; - } + if (current_device != device_) { + cudaSetDevice(current_device); } return cudaSuccess; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index 0256f0d0..1bed599f 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -40,6 +40,7 @@ #include "cutlass/core_io.h" #include +#include #include "cutlass/profiler/cublas_helpers.h" #include "cutlass/profiler/gemm_operation_profiler.h" @@ -195,9 +196,6 @@ Status GemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->swizzle_size, "swizzle_size", problem_space, problem)) { // default value this->swizzle_size = 1; - if (this->swizzle_size <= 0) { - return Status::kErrorInvalidProblem; - } } if (!arg_as_RasterOrder(this->raster_order, "raster_order", problem_space, problem)) { @@ -371,31 +369,49 @@ Status GemmOperationProfiler::initialize_configuration( return status; } - gemm_workspace_.configuration.mode = problem_.mode; - gemm_workspace_.configuration.problem_size.m() = int(problem_.m); - gemm_workspace_.configuration.problem_size.n() = int(problem_.n); - gemm_workspace_.configuration.problem_size.k() = int(problem_.k); - gemm_workspace_.configuration.lda = problem_.lda; - gemm_workspace_.configuration.ldb = problem_.ldb; - gemm_workspace_.configuration.ldc = problem_.ldc; - gemm_workspace_.configuration.ldd = problem_.ldc; + const auto device_count = options.device.devices.size(); - if (problem_.mode == library::GemmUniversalMode::kBatched) { - gemm_workspace_.configuration.batch_count = problem_.batch_count; - } - else { - gemm_workspace_.configuration.batch_count = problem_.split_k_slices; + gemm_workspace_.clear(); + + for (size_t i = 0; i < device_count; ++i) { + cudaSetDevice(options.device.device_id(i)); + gemm_workspace_.emplace_back(); + cudaStreamCreateWithFlags(&gemm_workspace_[i].stream, cudaStreamNonBlocking); + gemm_workspace_[i].configuration.mode = problem_.mode; + gemm_workspace_[i].configuration.problem_size.m() = int(problem_.m); + gemm_workspace_[i].configuration.problem_size.n() = int(problem_.n); + gemm_workspace_[i].configuration.problem_size.k() = int(problem_.k); + gemm_workspace_[i].configuration.lda = problem_.lda; + gemm_workspace_[i].configuration.ldb = problem_.ldb; + gemm_workspace_[i].configuration.ldc = problem_.ldc; + gemm_workspace_[i].configuration.ldd = problem_.ldc; + + gemm_workspace_[i].configuration.device_count = static_cast(device_count); + gemm_workspace_[i].arguments.device_index = static_cast(i); + + if (problem_.mode == library::GemmUniversalMode::kBatched) { + gemm_workspace_[i].configuration.batch_count = problem_.batch_count; + } + else { + gemm_workspace_[i].configuration.batch_count = problem_.split_k_slices; + } + + gemm_workspace_[i].arguments.A = nullptr; + gemm_workspace_[i].arguments.B = nullptr; + gemm_workspace_[i].arguments.C = nullptr; + gemm_workspace_[i].arguments.D = nullptr; + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.swizzle_size = problem_.swizzle_size; + gemm_workspace_[i].arguments.raster_order = problem_.raster_order; + initialize_result_(this->model_result_, options, operation_desc, problem_space); + + if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) { + return can_implement; + } } - gemm_workspace_.arguments.A = nullptr; - gemm_workspace_.arguments.B = nullptr; - gemm_workspace_.arguments.C = nullptr; - gemm_workspace_.arguments.D = nullptr; - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.swizzle_size = problem_.swizzle_size; - gemm_workspace_.arguments.raster_order = problem_.raster_order; // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { if (!initialize_reduction_configuration_(operation, problem)) { @@ -403,9 +419,7 @@ Status GemmOperationProfiler::initialize_configuration( } } - initialize_result_(this->model_result_, options, operation_desc, problem_space); - - return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); + return status; } /// Initializes the performance result @@ -427,6 +441,7 @@ void GemmOperationProfiler::initialize_result_( result.bytes = problem_.bytes(operation_desc); result.flops = problem_.flops(operation_desc); result.runtime = 0; + result.runtime_vector.resize(options.device.devices.size(), 0); } @@ -447,12 +462,14 @@ bool GemmOperationProfiler::initialize_reduction_configuration_( } /// initialize library::ReductionConfiguration - gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); - gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices); - gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); - gemm_workspace_.reduction_configuration.ldw = problem_.ldc; - gemm_workspace_.reduction_configuration.lds = problem_.ldc; - gemm_workspace_.reduction_configuration.ldd = problem_.ldc; + for (auto &gemm_workspace : gemm_workspace_) { + gemm_workspace.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn(); + gemm_workspace.reduction_configuration.partitions = int(problem_.split_k_slices); + gemm_workspace.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product(); + gemm_workspace.reduction_configuration.ldw = problem_.ldc; + gemm_workspace.reduction_configuration.lds = problem_.ldc; + gemm_workspace.reduction_configuration.ldd = problem_.ldc; + } // find reduction operation library::ReductionFunctionalKey reduction_key( @@ -485,11 +502,6 @@ Status GemmOperationProfiler::initialize_workspace( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - if (options.device.devices.size() != 1) { - throw std::runtime_error("This operation profiler only supports a single " - "device."); - } - cudaError_t result; result = cudaSetDevice(options.device.device_id(0)); if (result != cudaSuccess) { @@ -509,98 +521,103 @@ Status GemmOperationProfiler::initialize_workspace( bool is_sparse = operation_desc.tile_description.math_instruction.opcode_class == cutlass::library::OpcodeClassID::kSparseTensorOp; - // Compute the number of copies of the problem to avoid L2 camping. - if (!options.profiling.workspace_count) { - int64_t bytes = problem_.bytes(operation_desc); - if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { - gemm_workspace_.problem_count = - 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + + // Compute the number of copies of the problem to avoid L2 camping. + if (!options.profiling.workspace_count) { + int64_t bytes = problem_.bytes(operation_desc); + if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { + gemm_workspace_[i].problem_count = + 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + } + else { + gemm_workspace_[i].problem_count = 1; + } } else { - gemm_workspace_.problem_count = 1; + gemm_workspace_[i].problem_count = options.profiling.workspace_count; } - } - else { - gemm_workspace_.problem_count = options.profiling.workspace_count; - } - bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; - if (allocate_device_tensors) { - int seed_shift = 0; - gemm_workspace_.A = device_context.allocate_and_initialize_tensor( - options, - "A", - operation_desc.A.element, - operation_desc.A.layout, - {int(problem_.m), int(problem_.k)}, - {int(problem_.lda)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); + bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; + if (allocate_device_tensors) { + int seed_shift = 0; + gemm_workspace_[i].A = device_context.allocate_and_initialize_tensor( + options, + "A", + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k)}, + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); - gemm_workspace_.B = device_context.allocate_and_initialize_tensor( - options, - "B", - operation_desc.B.element, - operation_desc.B.layout, - {int(problem_.k), int(problem_.n)}, - {int(problem_.ldb)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); + gemm_workspace_[i].B = device_context.allocate_and_initialize_tensor( + options, + "B", + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); - gemm_workspace_.C = device_context.allocate_and_initialize_tensor( - options, - "C", - operation_desc.C.element, - operation_desc.C.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - seed_shift++, - 0 // device_index - ); + gemm_workspace_[i].C = device_context.allocate_and_initialize_tensor( + options, + "C", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + seed_shift++, + i // device_index + ); - gemm_workspace_.Computed = device_context.allocate_tensor( - options, - "D", - operation_desc.D.element, - operation_desc.D.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - 0 // device_index - ); + gemm_workspace_[i].Computed = device_context.allocate_tensor( + options, + "D", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); - gemm_workspace_.Reference = device_context.allocate_tensor( - options, - "Reference", - operation_desc.D.element, - operation_desc.D.layout, - {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)}, - problem_.batch_count * gemm_workspace_.problem_count, - 0 // device_index - ); - } + gemm_workspace_[i].Reference = device_context.allocate_tensor( + options, + "Reference", + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + } - if (options.execution_mode != ExecutionMode::kDryRun) { - // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels - gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; - gemm_workspace_.arguments.batch_count = problem_.batch_count; - gemm_workspace_.arguments.lda = problem_.lda; - gemm_workspace_.arguments.ldb = problem_.ldb; - gemm_workspace_.arguments.ldc = problem_.ldc; - gemm_workspace_.arguments.ldd = problem_.ldc; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + if (options.execution_mode != ExecutionMode::kDryRun) { + // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels + gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_[i].arguments.batch_count = problem_.batch_count; + gemm_workspace_[i].arguments.lda = problem_.lda; + gemm_workspace_[i].arguments.ldb = problem_.ldb; + gemm_workspace_[i].arguments.ldc = problem_.ldc; + gemm_workspace_[i].arguments.ldd = problem_.ldc; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); - /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + /* Query device SM count to pass onto the kernel as an argument, where needed */ + gemm_workspace_[i].arguments.sm_count = options.device.properties[0].multiProcessorCount; + gemm_workspace_[i].arguments.device_index = static_cast(i); + } } // @@ -611,58 +628,69 @@ Status GemmOperationProfiler::initialize_workspace( if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { if (options.execution_mode != ExecutionMode::kDryRun) { - uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); - gemm_workspace_.host_workspace.resize(workspace_size, 0); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_[i].configuration); + gemm_workspace_[i].host_workspace.resize(workspace_size, 0); - workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, - &gemm_workspace_.arguments); - if (is_sparse) { - // sparse gemm get_device_workspace_size() only return device workspace size per iteration - // Needs to multiply it w/ number of iteration - workspace_size *= gemm_workspace_.problem_count; - } - gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_[i].configuration, + &gemm_workspace_[i].arguments); + if (is_sparse) { + // sparse gemm get_device_workspace_size() only return device workspace size per iteration + // Needs to multiply it w/ number of iteration + workspace_size *= gemm_workspace_[i].problem_count; + } + gemm_workspace_[i].device_workspace.reset(library::NumericTypeID::kU8, workspace_size); - // Convert to structure sparse contents here. - if (is_sparse) { - uint8_t* profiler_workspaces[1]; - profiler_workspaces[0] = reinterpret_cast(gemm_workspace_.A->data()); - // Sparse operations have a different initialize interface. - // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE - auto modifiable_underlying_op = const_cast(underlying_operation); - status = modifiable_underlying_op->initialize_with_profiler_workspace( - &gemm_workspace_.configuration, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data(), - profiler_workspaces, - gemm_workspace_.problem_count); - } - else { - status = underlying_operation->initialize( - &gemm_workspace_.configuration, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data()); - } - - if (status != Status::kSuccess) { - return status; - } - - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration); - gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0); - - status = reduction_op_->initialize( - &gemm_workspace_.reduction_configuration, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + // Convert to structure sparse contents here. + if (is_sparse) { + uint8_t* profiler_workspaces[1]; + profiler_workspaces[0] = reinterpret_cast(gemm_workspace_[i].A->data()); + // Sparse operations have a different initialize interface. + // initialize_with_profiler_workspace converts mxk tensorA to compressed mxk/sp tensorA and the tensorE + auto modifiable_underlying_op = const_cast(underlying_operation); + status = modifiable_underlying_op->initialize_with_profiler_workspace( + &gemm_workspace_[i].configuration, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + profiler_workspaces, + gemm_workspace_[i].problem_count, + gemm_workspace_[i].stream); + } + else { + status = underlying_operation->initialize( + &gemm_workspace_[i].configuration, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); + } if (status != Status::kSuccess) { return status; } + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_[i].reduction_configuration); + gemm_workspace_[i].reduction_host_workspace.resize(workspace_size, 0); + + status = reduction_op_->initialize( + &gemm_workspace_[i].reduction_configuration, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (status != Status::kSuccess) { + return status; + } + } } } + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaDeviceSynchronize(); + } + // // If CUTLASS is enabled, generate a result for it // @@ -698,29 +726,31 @@ bool GemmOperationProfiler::verify_cutlass( } // Initialize structure containing GEMM arguments - gemm_workspace_.arguments.A = gemm_workspace_.A->data(); - gemm_workspace_.arguments.B = gemm_workspace_.B->data(); - gemm_workspace_.arguments.C = gemm_workspace_.C->data(); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->data(); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->data(); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->data(); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); - gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_[i].arguments.beta = problem_.beta_zero.data(); - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); - gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); - gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); - gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->data(); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_[i].reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } } // @@ -737,27 +767,33 @@ bool GemmOperationProfiler::verify_cutlass( } } - results_.back().status = underlying_operation->run( - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data()); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); - if (results_.back().status != Status::kSuccess) { - results_.back().disposition = Disposition::kFailed; - return false; - } - - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - results_.back().status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + results_.back().status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); if (results_.back().status != Status::kSuccess) { results_.back().disposition = Disposition::kFailed; return false; } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + results_.back().status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + } } cudaError_t result = cudaDeviceSynchronize(); @@ -784,13 +820,17 @@ bool GemmOperationProfiler::verify_cutlass( if (cublas_satisfies(gemm_desc) == Status::kSuccess) { // call cublas verification if supported - verify_with_cublas_( - options, - report, - device_context, - operation, - problem_space, - problem); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + verify_with_cublas_( + options, + report, + device_context, + operation, + problem_space, + problem, + gemm_workspace_[i]); + } } else { @@ -852,7 +892,8 @@ bool GemmOperationProfiler::verify_with_cublas_( DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, - ProblemSpace::Problem const &problem) { + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace_) { #if CUTLASS_ENABLE_CUBLAS @@ -983,115 +1024,119 @@ bool GemmOperationProfiler::verify_with_reference_( continue; } - void *ptr_A = gemm_workspace_.A->data(); - void *ptr_B = gemm_workspace_.B->data(); - void *ptr_C = gemm_workspace_.C->data(); - void *ptr_D = gemm_workspace_.Reference->data(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); - // To support the host-side reference, conditionally allocate and - // copy tensors to host memory. - std::vector host_data_A; - std::vector host_data_B; - std::vector host_data_C; - std::vector host_data_D; + void *ptr_A = gemm_workspace_[i].A->data(); + void *ptr_B = gemm_workspace_[i].B->data(); + void *ptr_C = gemm_workspace_[i].C->data(); + void *ptr_D = gemm_workspace_[i].Reference->data(); - if (provider == library::Provider::kReferenceHost) { + // To support the host-side reference, conditionally allocate and + // copy tensors to host memory. + std::vector host_data_A; + std::vector host_data_B; + std::vector host_data_C; + std::vector host_data_D; - host_data_A.resize(gemm_workspace_.A->bytes()); - ptr_A = host_data_A.data(); - gemm_workspace_.A->copy_to_host(ptr_A); + if (provider == library::Provider::kReferenceHost) { - host_data_B.resize(gemm_workspace_.B->bytes()); - ptr_B = host_data_B.data(); - gemm_workspace_.B->copy_to_host(ptr_B); + host_data_A.resize(gemm_workspace_[i].A->bytes()); + ptr_A = host_data_A.data(); + gemm_workspace_[i].A->copy_to_host(ptr_A); - host_data_C.resize(gemm_workspace_.C->bytes()); - ptr_C = host_data_C.data(); - gemm_workspace_.C->copy_to_host(ptr_C); + host_data_B.resize(gemm_workspace_[i].B->bytes()); + ptr_B = host_data_B.data(); + gemm_workspace_[i].B->copy_to_host(ptr_B); - host_data_D.resize(gemm_workspace_.Reference->bytes()); - ptr_D = host_data_D.data(); - } + host_data_C.resize(gemm_workspace_[i].C->bytes()); + ptr_C = host_data_C.data(); + gemm_workspace_[i].C->copy_to_host(ptr_C); - // - // Launch - // + host_data_D.resize(gemm_workspace_[i].Reference->bytes()); + ptr_D = host_data_D.data(); + } - library::Handle handle; + // + // Launch + // - handle.set_provider(provider); + library::Handle handle; - Status status = handle.gemm_universal( - problem_.mode, - gemm_workspace_.configuration.problem_size.m(), - gemm_workspace_.configuration.problem_size.n(), - gemm_workspace_.configuration.problem_size.k(), - gemm_desc.tile_description.math_instruction.element_accumulator, - gemm_desc.element_epilogue, + handle.set_provider(provider); - problem_.alpha.data(), + Status status = handle.gemm_universal( + problem_.mode, + gemm_workspace_[i].configuration.problem_size.m(), + gemm_workspace_[i].configuration.problem_size.n(), + gemm_workspace_[i].configuration.problem_size.k(), + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, - element_A, - gemm_desc.A.layout, - gemm_desc.transform_A, - ptr_A, - int(gemm_workspace_.configuration.lda), + problem_.alpha.data(), - element_B, - gemm_desc.B.layout, - gemm_desc.transform_B, - ptr_B, - int(gemm_workspace_.configuration.ldb), + element_A, + gemm_desc.A.layout, + gemm_desc.transform_A, + ptr_A, + int(gemm_workspace_[i].configuration.lda), - problem_.beta.data(), + element_B, + gemm_desc.B.layout, + gemm_desc.transform_B, + ptr_B, + int(gemm_workspace_[i].configuration.ldb), - gemm_desc.C.element, - gemm_desc.C.layout, - ptr_C, - int(gemm_workspace_.configuration.ldc), + problem_.beta.data(), - gemm_desc.D.element, - gemm_desc.D.layout, - ptr_D, - int(gemm_workspace_.configuration.ldd), + gemm_desc.C.element, + gemm_desc.C.layout, + ptr_C, + int(gemm_workspace_[i].configuration.ldc), - gemm_workspace_.configuration.batch_count, - gemm_workspace_.A->batch_stride(), - gemm_workspace_.B->batch_stride(), - gemm_workspace_.C->batch_stride(), - gemm_workspace_.Reference->batch_stride()); + gemm_desc.D.element, + gemm_desc.D.layout, + ptr_D, + int(gemm_workspace_[i].configuration.ldd), - if (status != Status::kSuccess) { - results_.back().verification_map[provider] = Disposition::kNotRun; - continue; - } - results_.back().status = status; + gemm_workspace_[i].configuration.batch_count, + gemm_workspace_[i].A->batch_stride(), + gemm_workspace_[i].B->batch_stride(), + gemm_workspace_[i].C->batch_stride(), + gemm_workspace_[i].Reference->batch_stride()); - if (provider == library::Provider::kReferenceHost) { - gemm_workspace_.Reference->copy_from_host(ptr_D); - } + if (status != Status::kSuccess) { + results_.back().verification_map[provider] = Disposition::kNotRun; + continue; + } + results_.back().status = status; - // - // Verify results - // + if (provider == library::Provider::kReferenceHost) { + gemm_workspace_[i].Reference->copy_from_host(ptr_D); + } - results_.back().verification_map[provider] = compare_tensors( - options, - *gemm_workspace_.Computed, - *gemm_workspace_.Reference, - gemm_workspace_.Computed->batch_stride() - ); + // + // Verify results + // - // Save workspace if incorrect - if (options.verification.save_workspace == SaveWorkspace::kIncorrect && - results_.back().verification_map[provider] == Disposition::kIncorrect) { - - save_workspace( - device_context, + results_.back().verification_map[provider] = compare_tensors( options, - gemm_desc, - library::Provider::kCUTLASS, - provider); + *gemm_workspace_[i].Computed, + *gemm_workspace_[i].Reference, + gemm_workspace_[i].Computed->batch_stride() + ); + + // Save workspace if incorrect + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + results_.back().verification_map[provider] == Disposition::kIncorrect) { + + save_workspace( + device_context, + options, + gemm_desc, + library::Provider::kCUTLASS, + provider); + } } } @@ -1100,6 +1145,18 @@ bool GemmOperationProfiler::verify_with_reference_( ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace { +extern "C" { + __global__ void delay(cuda::atomic const* release) { + while (release->load(cuda::memory_order_acquire) != true) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) + __nanosleep(100); +#endif + } + } +} +} + /// Measures performance results bool GemmOperationProfiler::profile( Options const &options, @@ -1111,39 +1168,41 @@ bool GemmOperationProfiler::profile( if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { - // Initialize structure containing GEMM arguments - gemm_workspace_.arguments.A = gemm_workspace_.A->data(); - gemm_workspace_.arguments.B = gemm_workspace_.B->data(); - gemm_workspace_.arguments.C = gemm_workspace_.C->data(); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + // Initialize structure containing GEMM arguments + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->data(); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->data(); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->data(); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].arguments.beta = problem_.beta.data(); + gemm_workspace_[i].arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].arguments.batch_stride_A = gemm_workspace_[i].A->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_B = gemm_workspace_[i].B->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); + gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); - gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_[i].arguments.beta = problem_.beta_zero.data(); - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); - gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); - gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); - gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->data(); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->data(); + gemm_workspace_[i].reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_[i].reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_[i].reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } } results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data() + nullptr, + nullptr, + nullptr ); } return true; @@ -1153,14 +1212,22 @@ bool GemmOperationProfiler::profile( /// Method to profile a CUTLASS Operation Status GemmOperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, - void *arguments, - void *host_workspace, - void *device_workspace) { + void *, + void *, + void *) { - GpuTimer timer; + cuda::atomic *release; + cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable); + release->store(false, cuda::memory_order_release); + + std::vector timer; + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + timer.emplace_back(); + } // initialize gemm underlying operation to handle parallel reduction library::Operation const * underlying_operation = operation; @@ -1182,110 +1249,158 @@ Status GemmOperationProfiler::profile_cutlass_( Status status; - for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { + std::vector graphs; + graphs.resize(gemm_workspace_.size()); + std::vector graphExecs; + graphExecs.resize(gemm_workspace_.size()); - int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaStreamBeginCapture(gemm_workspace_[i].stream, cudaStreamCaptureModeGlobal); + // Halt execution until all GPUs are ready to precede. + // It allows the CPU to trigger the GPUs all start at the same time. + delay<<<1, 1, 0, gemm_workspace_[i].stream>>>(release); + for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { + int problem_idx = (iteration % gemm_workspace_[i].problem_count) * problem_.batch_count; - gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); - gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); - gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->batch_data(problem_idx); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->batch_data(problem_idx); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->batch_data(problem_idx); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); - } + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->batch_data(problem_idx); + } - // Execute the CUTLASS operation - status = underlying_operation->run( - &gemm_workspace_.arguments, - host_workspace, - device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); + // Execute the CUTLASS operation + status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); if (status != Status::kSuccess) { return status; } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (status != Status::kSuccess) { + return status; + } + } } + + // + // Initialize GPU timer + // + + timer[i].start(gemm_workspace_[i].stream, cudaEventRecordExternal); + + // + // Profiling loop + // + + int Iterations = options.profiling.iterations; + + int iteration = 0; + + for (; iteration < Iterations; ++iteration) { + // Iterate over copies of the problem in memory + int workspace_idx = options.profiling.warmup_iterations + iteration; + int problem_idx = (workspace_idx % gemm_workspace_[i].problem_count) * problem_.batch_count; + + gemm_workspace_[i].arguments.A = gemm_workspace_[i].A->batch_data(problem_idx); + gemm_workspace_[i].arguments.B = gemm_workspace_[i].B->batch_data(problem_idx); + gemm_workspace_[i].arguments.C = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].arguments.D = gemm_workspace_[i].Computed->batch_data(problem_idx); + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); + + gemm_workspace_[i].reduction_arguments.workspace = gemm_workspace_[i].device_workspace.data(); + gemm_workspace_[i].reduction_arguments.source = gemm_workspace_[i].C->batch_data(problem_idx); + gemm_workspace_[i].reduction_arguments.destination = gemm_workspace_[i].Computed->batch_data(problem_idx); + } + + status = underlying_operation->run( + &gemm_workspace_[i].arguments, + gemm_workspace_[i].host_workspace.data(), + gemm_workspace_[i].device_workspace.data(), + gemm_workspace_[i].stream); + + if (status != Status::kSuccess) { + return status; + } + + // Run parallel reduction kernel for parallel split_k_mode + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + status = reduction_op_->run( + &gemm_workspace_[i].reduction_arguments, + gemm_workspace_[i].reduction_host_workspace.data(), + nullptr, + gemm_workspace_[i].stream); + + if (status != Status::kSuccess) { + return status; + } + } + } + timer[i].stop(gemm_workspace_[i].stream, cudaEventRecordExternal); + cudaStreamEndCapture(gemm_workspace_[i].stream, &graphs[i]); + cudaGraphInstantiate(&graphExecs[i], graphs[i], nullptr, nullptr, 0); } - // - // Initialize GPU timer - // - - timer.start(); - - // - // Profiling loop - // - - int Iterations = options.profiling.iterations; - - int iteration = 0; - for (; iteration < Iterations; ++iteration) { - - // Iterate over copies of the problem in memory - int workspace_idx = options.profiling.warmup_iterations + iteration; - int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; - - gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); - gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); - gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); - - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx); - } - - status = underlying_operation->run( - arguments, - host_workspace, - device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - // Run parallel reduction kernel for parallel split_k_mode - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - status = reduction_op_->run( - &gemm_workspace_.reduction_arguments, - gemm_workspace_.reduction_host_workspace.data(), - nullptr); - - if (status != Status::kSuccess) { - return status; - } - } + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaGraphLaunch(graphExecs[i], gemm_workspace_[i].stream); } // // Wait for completion // - timer.stop_and_wait(); + release->store(true, cuda::memory_order_release); + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaStreamSynchronize(gemm_workspace_[i].stream); + } // // Update performance result // - runtime = timer.duration(iteration); + + result.runtime = 0; + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + result.runtime_vector[i] = timer[i].duration(options.profiling.iterations); + result.runtime += result.runtime_vector[i]; + } + result.runtime /= static_cast(gemm_workspace_.size()); + + cudaFreeHost(release); + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(i)); + cudaGraphExecDestroy(graphExecs[i]); + cudaGraphDestroy(graphs[i]); + } + + for (size_t i = 0; i < gemm_workspace_.size(); ++i) { + cudaSetDevice(options.device.device_id(gemm_workspace_.size() - i - 1)); + timer.pop_back(); + } return status; } diff --git a/tools/profiler/src/gpu_timer.cpp b/tools/profiler/src/gpu_timer.cpp index cf03db18..cd0e4df0 100644 --- a/tools/profiler/src/gpu_timer.cpp +++ b/tools/profiler/src/gpu_timer.cpp @@ -33,9 +33,11 @@ */ #include +#include #include "cutlass/profiler/gpu_timer.h" + namespace cutlass { namespace profiler { @@ -52,32 +54,39 @@ GpuTimer::GpuTimer() { } } +GpuTimer::GpuTimer(GpuTimer&& gpu_timer) noexcept { + memcpy(events, gpu_timer.events, sizeof(events)); + memset(gpu_timer.events, 0, sizeof(gpu_timer.events)); +} + GpuTimer::~GpuTimer() { - for (auto & event : events) { - cudaEventDestroy(event); + for (const auto & event : events) { + if (event != nullptr) { + cudaEventDestroy(event); + } } } -/// Records a start event in the stream -void GpuTimer::start(cudaStream_t stream) { - cudaError_t result = cudaEventRecord(events[0], stream); +/// Records a start event in the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::start(cudaStream_t stream, const unsigned int flag) { + cudaError_t result = cudaEventRecordWithFlags(events[0], stream, flag); if (result != cudaSuccess) { throw std::runtime_error("Failed to record start event."); } } -/// Records a stop event in the stream -void GpuTimer::stop(cudaStream_t stream) { -cudaError_t result = cudaEventRecord(events[1], stream); +/// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::stop(cudaStream_t stream, const unsigned int flag) { +cudaError_t result = cudaEventRecordWithFlags(events[1], stream, flag); if (result != cudaSuccess) { throw std::runtime_error("Failed to record stop event."); } } -/// Records a stop event in the stream and synchronizes on the stream -void GpuTimer::stop_and_wait(cudaStream_t stream) { +/// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags +void GpuTimer::stop_and_wait(cudaStream_t stream, const unsigned int flag) { - stop(stream); + stop(stream, flag); cudaError_t result; if (stream) { diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index ce1ebb21..4d5c9d09 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -658,7 +658,7 @@ void OperationProfiler::save_workspace( /// Method to profile a CUTLASS Operation Status OperationProfiler::profile_cutlass_( - double &runtime, + PerformanceResult &result, Options const &options, library::Operation const *operation, void *arguments, @@ -726,7 +726,7 @@ Status OperationProfiler::profile_cutlass_( // Update performance result // - runtime = timer.duration(iteration); + result.runtime = timer.duration(iteration); return status; } diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index f1c1d7a7..59368e9b 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -307,12 +307,6 @@ void Options::Initialization::get_distribution( {0, 0} }; - // Initalize pnz values to a default value of 100% - dist.gaussian.pnz = 1.0; - dist.gaussian.pnzA = 1.0; - dist.gaussian.pnzB = 1.0; - dist.gaussian.pnzC = 1.0; - using KeyValueVector = std::vector >; KeyValueVector values; @@ -330,6 +324,25 @@ void Options::Initialization::get_distribution( ++it; } + // Default initialization + switch (dist.kind) { + case cutlass::Distribution::Uniform: + dist.set_uniform(-4/*min*/, 4/*max*/); + break; + case cutlass::Distribution::Gaussian: + dist.set_gaussian(0/*mean*/, 4/*stddev*/); + break; + case cutlass::Distribution::Identity: + dist.set_identity(); + break; + case cutlass::Distribution::Sequential: + dist.set_sequential(0/*start*/, 4/*delta*/); + break; + default: + dist.set_uniform(-4/*min*/, 4/*max*/); + return; + } + // Subsequent key-value pairs update the named field of the distribution struct. for (; it != values.end(); ++it) { // Integer scaling factor - if < 0, no integer rounding is performed. diff --git a/tools/profiler/src/performance_report.cpp b/tools/profiler/src/performance_report.cpp index 12855a2f..1d04f48f 100644 --- a/tools/profiler/src/performance_report.cpp +++ b/tools/profiler/src/performance_report.cpp @@ -337,7 +337,15 @@ std::ostream & PerformanceReport::print_csv_header_( << ",Bytes" << ",Flops" << ",Flops/Byte" - << ",Runtime" + << ",Runtime"; + + if (options_.device.devices.size() > 1) { + for (size_t i = 0; i < options_.device.devices.size(); i++) { + out << ",Runtime_" << i; + } + } + + out << ",GB/s" << ",GFLOPs" ; @@ -376,6 +384,16 @@ std::ostream & PerformanceReport::print_result_csv_( << "," << result.flops / result.bytes << "," << result.runtime; + if (options_.device.devices.size() > 1) { + if (result.runtime_vector.size() != options_.device.devices.size()) { + throw std::runtime_error("Runtime vector size mismatch"); + } + + for (const auto runtime : result.runtime_vector) { + out << "," << runtime; + } + } + if (result.good()) { out diff --git a/tools/profiler/src/rank_2k_operation_profiler.cu b/tools/profiler/src/rank_2k_operation_profiler.cu index df8ad40f..4b547a3e 100644 --- a/tools/profiler/src/rank_2k_operation_profiler.cu +++ b/tools/profiler/src/rank_2k_operation_profiler.cu @@ -733,7 +733,7 @@ bool Rank2KOperationProfiler::profile( rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &rank_k_workspace_.arguments, diff --git a/tools/profiler/src/rank_k_operation_profiler.cu b/tools/profiler/src/rank_k_operation_profiler.cu index 49fe54c8..52613b8e 100644 --- a/tools/profiler/src/rank_k_operation_profiler.cu +++ b/tools/profiler/src/rank_k_operation_profiler.cu @@ -718,7 +718,7 @@ bool RankKOperationProfiler::profile( rank_k_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &rank_k_workspace_.arguments, diff --git a/tools/profiler/src/sparse_gemm_operation_profiler.cu b/tools/profiler/src/sparse_gemm_operation_profiler.cu index 939608f8..ec14a332 100644 --- a/tools/profiler/src/sparse_gemm_operation_profiler.cu +++ b/tools/profiler/src/sparse_gemm_operation_profiler.cu @@ -578,7 +578,7 @@ bool SparseGemmOperationProfiler::profile( gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &gemm_workspace_.arguments, diff --git a/tools/profiler/src/symm_operation_profiler.cu b/tools/profiler/src/symm_operation_profiler.cu index 59fcf0f1..80f645e7 100644 --- a/tools/profiler/src/symm_operation_profiler.cu +++ b/tools/profiler/src/symm_operation_profiler.cu @@ -771,7 +771,7 @@ bool SymmOperationProfiler::profile( symm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &symm_workspace_.arguments, diff --git a/tools/profiler/src/trmm_operation_profiler.cu b/tools/profiler/src/trmm_operation_profiler.cu index 5983b011..9d3b4db6 100644 --- a/tools/profiler/src/trmm_operation_profiler.cu +++ b/tools/profiler/src/trmm_operation_profiler.cu @@ -709,7 +709,7 @@ bool TrmmOperationProfiler::profile( trmm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; results_.back().status = profile_cutlass_( - results_.back().runtime, + results_.back(), options, operation, &trmm_workspace_.arguments, diff --git a/tools/util/include/cutlass/util/device_dump.h b/tools/util/include/cutlass/util/device_dump.h index cd016160..bb20e9b7 100644 --- a/tools/util/include/cutlass/util/device_dump.h +++ b/tools/util/include/cutlass/util/device_dump.h @@ -31,7 +31,7 @@ #pragma once -#include +#include #include "cutlass/cutlass.h" /** diff --git a/tools/util/include/cutlass/util/device_groupnorm.h b/tools/util/include/cutlass/util/device_groupnorm.h index 07f56c71..5fc93a11 100644 --- a/tools/util/include/cutlass/util/device_groupnorm.h +++ b/tools/util/include/cutlass/util/device_groupnorm.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_layernorm.h b/tools/util/include/cutlass/util/device_layernorm.h index 0ee58a7f..7708c3eb 100644 --- a/tools/util/include/cutlass/util/device_layernorm.h +++ b/tools/util/include/cutlass/util/device_layernorm.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_nhwc_pooling.h b/tools/util/include/cutlass/util/device_nhwc_pooling.h index 05fe5584..cce452d9 100644 --- a/tools/util/include/cutlass/util/device_nhwc_pooling.h +++ b/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -42,7 +42,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "device_utils.h" -#include +#include namespace cutlass { diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h b/tools/util/include/cutlass/util/device_rmsnorm.h index c4542eff..44a1c084 100644 --- a/tools/util/include/cutlass/util/device_rmsnorm.h +++ b/tools/util/include/cutlass/util/device_rmsnorm.h @@ -37,7 +37,7 @@ #include "cutlass/tensor_coord.h" #include "cutlass/tensor_ref.h" #include "cutlass/util/device_utils.h" -#include +#include namespace cutlass { @@ -165,12 +165,12 @@ void rmsnorm(cutlass::MatrixCoord tensor_size, dim3 grid(m); if (n % 8 == 0 && std::is_same::value) { - dim3 block(min(1024, (n / 8 + 31) / 32 * 32)); + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); rmsnorm_twoPassAlgo_e8<<>>( (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); } else { - dim3 block(min(1024, ((n + 31)/32 + 31)/32*32)); + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); rmsnorm_twoPassAlgo_e1<<>>( output, input, weight, m, n, epsilon); diff --git a/tools/util/include/cutlass/util/device_utils.h b/tools/util/include/cutlass/util/device_utils.h index 3ec078c8..7a8378fc 100644 --- a/tools/util/include/cutlass/util/device_utils.h +++ b/tools/util/include/cutlass/util/device_utils.h @@ -36,7 +36,7 @@ #pragma once #include -#include +#include #define FINAL_MASK 0xffffffff struct half4 { diff --git a/tools/util/include/cutlass/util/distribution.h b/tools/util/include/cutlass/util/distribution.h index 649a5736..086e033a 100644 --- a/tools/util/include/cutlass/util/distribution.h +++ b/tools/util/include/cutlass/util/distribution.h @@ -100,6 +100,9 @@ struct Distribution { gaussian.mean = _mean; gaussian.stddev = _stddev; gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; int_scale = _int_scale; return *this; } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 13aedf14..059076d9 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -138,8 +138,8 @@ struct RandomGaussianFunc { int_scale(int_scale_), exclude_zero(exclude_zero_) { - float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -175,8 +175,8 @@ struct RandomGaussianFunc { Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); - result = Element(IntType(rnd * params.float_scale_down)); + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -237,7 +237,6 @@ struct RandomGaussianFunc> { exclude_zero(exclude_zero_) { float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_up += FloatType(0.5) * float_scale_up; float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -276,8 +275,8 @@ struct RandomGaussianFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); - rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); result = { Real(rnd_r * params.float_scale_down), @@ -482,8 +481,8 @@ struct RandomUniformFunc { pnan(pnan_), exclude_zero(exclude_zero_) { - float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits - float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); // Handle cases where min = 0 or max = 0 for excluding zeros if (exclude_zero >= 0) { @@ -535,8 +534,8 @@ struct RandomUniformFunc { Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); - result = Element(IntType(rnd * params.float_scale_down)); + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -612,7 +611,6 @@ struct RandomUniformFunc> { exclude_zero(exclude_zero_) { float_scale_up = FloatType(IntType(1) << int_scale); - float_scale_up += FloatType(0.5) * float_scale_up; float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); // Handle cases where min = 0 or max = 0 for excluding zeros @@ -668,8 +666,8 @@ struct RandomUniformFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); - rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); result = { Real(rnd_r * params.float_scale_down), diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index f6984fb2..184d7737 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -281,7 +281,6 @@ void gett_epilogue( cute::is_same_v; constexpr bool IsClamp = cute::is_same_v>; - constexpr bool IsBackpropFusion = cute::is_same_v> or cute::is_same_v>; diff --git a/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h index 090019c1..9e1ac76c 100644 --- a/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h +++ b/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/tools/util/include/cutlass/util/reference/host/rank_k_complex.h index ef44270a..6f9d5dc4 100644 --- a/tools/util/include/cutlass/util/reference/host/rank_k_complex.h +++ b/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/symm_complex.h b/tools/util/include/cutlass/util/reference/host/symm_complex.h index 2618feaa..7a55bb39 100644 --- a/tools/util/include/cutlass/util/reference/host/symm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -41,7 +41,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" -#include +#include namespace cutlass { namespace reference { diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index b9f0c84d..85c70e41 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -269,8 +269,8 @@ struct RandomGaussianFunc > { // Sample from the Gaussian distribution for a nonzero element if (bernoulli_result) { if (int_scale >= 0) { - rnd[0] = double(int(rnd[0] * double(1 << int_scale))); - rnd[1] = double(int(rnd[1] * double(1 << int_scale))); + rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); + rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); reals[0] = from_real(rnd[0] / double(1 << int_scale)); reals[1] = from_real(rnd[1] / double(1 << int_scale)); } @@ -348,10 +348,10 @@ struct RandomGaussianFunc > { // Sample from the Gaussian distribution for a nonzero element if (bernoulli_result) { if (int_scale >= 0) { - rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); - rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); - rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); - rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); + rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); reals[0] = from_real(rnd1[0] / double(1 << int_scale)); reals[1] = from_real(rnd1[1] / double(1 << int_scale)); @@ -725,7 +725,7 @@ public: // testing if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); + rnd = double(std::llround(rnd * double(1 << int_scale))); reals[i] = from_real(Real(rnd / double(1 << int_scale))); } else { @@ -808,7 +808,7 @@ public: // testing if (int_scale >= 0) { - rnd = double(int(rnd * double(1 << int_scale))); + rnd = double(std::llround(rnd * double(1 << int_scale))); reals[i] = from_real(Real(rnd / double(1 << int_scale))); } else { diff --git a/tools/util/include/cutlass/util/type_traits.h b/tools/util/include/cutlass/util/type_traits.h index 8379957a..dec3168e 100644 --- a/tools/util/include/cutlass/util/type_traits.h +++ b/tools/util/include/cutlass/util/type_traits.h @@ -36,7 +36,7 @@ #include #include -#include +#include #include "cutlass/numeric_types.h" #include "cutlass/complex.h"