diff --git a/CHANGELOG.md b/CHANGELOG.md index 87f01574..8b9f0afc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,12 @@ # NVIDIA CUTLASS Changelog +## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4) (2023-12-29) +* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors. +* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) +* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above). +* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above). +* NamedBarriers usability improvement and list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) has been officially released. +* Improved [CuTe TMA Tensor](/media/docs/cute/0z_tma_tensors.md) documentation. + ## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31) * [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. diff --git a/CMakeLists.txt b/CMakeLists.txt index ec5dc5f2..0fa6feaa 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.3.0 LANGUAGES CXX) +project(CUTLASS VERSION 3.4.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -681,6 +681,12 @@ endif() ################################################################################ +set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default + activated test sets. In `make test` mode, this string determines the + active set of tests. In `ctest` mode, this value can be overriden + with CUTLASS_TEST_SETS environment variable when running the ctest + executable.") + set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake) set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "") @@ -701,11 +707,12 @@ function(cutlass_add_executable_tests NAME TARGET) # generating the full variable name to be referenced. # RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed # test results to speed up test runtime. +# TEST_SETS_SUPPORTED: A list of test set names these tests support. # set(options DISABLE_EXECUTABLE_INSTALL_RULE) set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX) - set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS) + set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if (NOT DEFINED __DISABLE_TESTS) @@ -715,6 +722,12 @@ function(cutlass_add_executable_tests NAME TARGET) set(TEST_EXE $) set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR}) + if (NOT DEFINED __TEST_SETS_SUPPORTED) + set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS}) + endif() + + set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED}) + if (__RESULT_CACHE_FILE) add_custom_command( @@ -816,8 +829,6 @@ function(cutlass_add_executable_tests NAME TARGET) set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) file(MAKE_DIRECTORY ${TEST_GEN_DIR}) - set(TEST_SETS_SUPPORTED default) - set(TEST_EXE_PATH $) set(TEST_USE_EXTENDED_FORMAT ON) configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) @@ -873,9 +884,9 @@ if (CUTLASS_INSTALL_TESTS) file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest") file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n") - - file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n") - file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n") file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n") foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES}) @@ -897,9 +908,15 @@ write_basic_package_version_file( ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake COMPATIBILITY AnyNewerVersion) +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake + @ONLY + ) + install( FILES - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/ ) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 5a159d8c..91537ea7 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -13,6 +13,7 @@ Cris Cecka
Aniket Shivam
Jack Kosaian
Mark Hoemmen
+Richard Cai
Honghao Lu
Ethan Yan
Haicheng Wu
@@ -21,6 +22,8 @@ Dustyn Blasig
Fengqi Qiao
Duane Merrill
Yujia Zhai
+Rawn Henry
+Sergey Klevtsov
Shang Zhang
Piotr Majcher
Paul Springer
@@ -55,6 +58,7 @@ Alan Kaatz
Tina Li
Timmy Liu
Wei Liu
+Tim Martin
Duane Merrill
Kevin Siu
Markus Tavenrath
diff --git a/CUDA.cmake b/CUDA.cmake index b9c60bcd..eb1ab9ad 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -248,11 +248,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR) message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required") endif() + if (NOT DEFINED __BATCH_SOURCES) + set(__BATCH_SOURCES ON) + endif() + if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE) set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE}) endif() - if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1) + if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1) set(CUDA_FILE_ARGS) set(TARGET_SOURCE_ARGS) diff --git a/README.md b/README.md index 4c43f1b9..53688efa 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.3 +# CUTLASS 3.4 -_CUTLASS 3.3 - October 2023_ +_CUTLASS 3.4 - December 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -41,17 +41,14 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.3 +# What's New in CUTLASS 3.4 -CUTLASS 3.3.0 is an update to CUTLASS adding: +CUTLASS 3.4.0 is an update to CUTLASS adding: -- New [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance. -- New [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8} and upcast on operandA {s8, u8} x {fp16, bf16}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance. -- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added. -- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details. -- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support. -- Support for Clang as a host compiler. -- Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface +- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100. +- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores now available. (Requires CUDA 12.3 or above) +- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) - commonly used in optimization of Mixture-Of-Expert models, is now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above) +- Impovements to NamedBarriers including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library. Minimum requirements: @@ -95,7 +92,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA CUTLASS requires a C++17 host compiler and performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive). -It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1. +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1 ## Operating Systems We have tested the following environments. @@ -107,6 +104,7 @@ We have tested the following environments. | Ubuntu 22.04 | GCC 11.2.0 | | Ubuntu 22.04 | Clang 10.0.0 | | Ubuntu 22.04 | Clang 14.0.6 | +| Ubuntu 22.04 | Clang 17.0.6 | | Windows 10.0 | Visual Studio 2019 v16.11.27 | Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake index 524ba1f8..2e1e50d8 100644 --- a/cmake/CTestTestfile.configure.cmake +++ b/cmake/CTestTestfile.configure.cmake @@ -2,10 +2,16 @@ set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@) -#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED) -#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.") -#? return() -#? endif() +if (NOT DEFINED ENV{CUTLASS_TEST_SETS}) + set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@) +endif() + +foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS}) + if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED) + message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].") + return() + endif() +endforeach() set(TEST_EXE_PATH @TEST_EXE_PATH@) set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@) diff --git a/cmake/NvidiaCutlassConfig.cmake b/cmake/NvidiaCutlassConfig.cmake.in similarity index 71% rename from cmake/NvidiaCutlassConfig.cmake rename to cmake/NvidiaCutlassConfig.cmake.in index 56d1c450..2fe69119 100644 --- a/cmake/NvidiaCutlassConfig.cmake +++ b/cmake/NvidiaCutlassConfig.cmake.in @@ -7,6 +7,3 @@ if(TARGET nvidia::cutlass::CUTLASS) endif() include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake") - -# For backward compatibility with the old name -add_library(cutlass_lib ALIAS cutlass_library) diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index c5498adf..0b1ad355 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -291,8 +291,8 @@ int run() { LayoutInputB, ElementOutput, LayoutOutput, - int32_t, - int32_t> + ElementComputeEpilogue, + ElementComputeEpilogue> gemm_device; // Launch device reference gemm kernel diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 563c2a0d..34bde0dc 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -279,7 +279,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "28_ampere_gemm_bias_fusion example\n\n" + out << "23_ampere_operand_gemm_reduction_fusion\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --m= GEMM M\n" @@ -297,7 +297,7 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n"; + << "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n"; return out; } diff --git a/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/examples/30_wgrad_split_k/30_wgrad_split_k.cu index e512242e..af7c6949 100644 --- a/examples/30_wgrad_split_k/30_wgrad_split_k.cu +++ b/examples/30_wgrad_split_k/30_wgrad_split_k.cu @@ -602,7 +602,7 @@ Result profile_convolution(Options const &options) { std::stringstream ss; - ss << "26_ampere_fused_wgrad_batch_normalization_" + ss << "30_wgrad_split_k_" << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() diff --git a/examples/34_transposed_conv2d/34_transposed_conv2d.cu b/examples/34_transposed_conv2d/34_transposed_conv2d.cu index 2e4ce3c0..668cd93a 100644 --- a/examples/34_transposed_conv2d/34_transposed_conv2d.cu +++ b/examples/34_transposed_conv2d/34_transposed_conv2d.cu @@ -251,7 +251,7 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; + << "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; return out; } diff --git a/examples/38_syr2k_grouped/syr2k_grouped.cu b/examples/38_syr2k_grouped/syr2k_grouped.cu index d8adb9c1..bebfd8cc 100644 --- a/examples/38_syr2k_grouped/syr2k_grouped.cu +++ b/examples/38_syr2k_grouped/syr2k_grouped.cu @@ -398,7 +398,7 @@ struct Options { << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n" << "# Execute Grouped SYR2K and profile with NSight\n" - << "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + << "$ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; return out; } diff --git a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu index 36bf775c..3f351a46 100644 --- a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu +++ b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu @@ -306,7 +306,7 @@ struct Options { /// Prints the usage statement. std::ostream &print_usage(std::ostream &out) const { - out << "41_depthwise_gemm_fprop example\n\n" + out << "46_depthwise_gemm_fprop example\n\n" << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" << " forward convolution on tensors of layout NHWC.\n\n" << "Options:\n\n" @@ -554,7 +554,7 @@ Result profile_convolution(Options const &options) { if (options.save_workspace) { std::stringstream ss; - ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() + ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index d0e47939..ba1407c5 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -291,7 +291,7 @@ struct ExampleRunner { using CustomEVT = // alpha * acc + beta * C cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) cutlass::epilogue::fusion::Sm90ScalarBroadcast, // beta - cutlass::epilogue::fusion::Sm90SrcFetch, // C + cutlass::epilogue::fusion::Sm90SrcFetch, // C cutlass::epilogue::fusion::Sm90EVT, // alpha * acc cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha cutlass::epilogue::fusion::Sm90AccFetch // acc @@ -302,7 +302,7 @@ struct ExampleRunner { // Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp // to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts. // These tags also provide additional metadata that can be queried at compile time. - using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index b962e3dc..de91a8cc 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -568,7 +568,7 @@ struct ExampleRunner if (options.reference_check) { if (!verify()) { std::cout << "Failed validation" << std::endl; -#if 1 +#if 0 debug_output(std::cout); #endif return false; diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 067b1dce..26a23d57 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -122,7 +122,7 @@ public: static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup; static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu index aa93304b..3515500c 100644 --- a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -109,6 +109,8 @@ namespace example { +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + struct Options { bool help; @@ -724,6 +726,7 @@ private: return true; } }; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) } // namespace example @@ -749,7 +752,7 @@ int main(int argc, char const **argv) if (notSupported) { return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems } - +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) example::Options options; options.parse(argc, argv); @@ -970,6 +973,6 @@ int main(int argc, char const **argv) result &= runner.run(options); } #endif - return result ? EXIT_SUCCESS : EXIT_FAILURE; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) } diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu index 080d7034..b3ad7e36 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -109,9 +109,11 @@ using ElementD = ElementC; using LayoutD = LayoutC; constexpr int AlignmentD = AlignmentC; -// Auxiliary matrix configuration +// Auxiliary matrix configuration and other fusion types using ElementAux = ElementC; using LayoutAux = LayoutC; +using ElementAmax = float; +using ElementBias = float; // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation @@ -124,7 +126,7 @@ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< - LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux>; + LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index 92c2207c..5fd2fe51 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -38,19 +38,41 @@ using INT8 tensor cores. The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap - A and B in the main loop. Consequently, it is essential to consider this when constructing the epilogue, as illustrated in this example. + A and B in the main loop. However, implicit swaps do not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales should have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the groups size + equal to the gemm problem K. Limitations: 1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}. 2) The narrow type must always be in K-major format. - 3) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. - 4) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the - operands to ensure the narrow type passes through the register file, and TMA epilogues do not currently support swap + transpose operations. - We plan to address this limitation in the future. + 3) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 4) The scales and the zeros must have the same layout and groupsize. + 5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 - $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 + + Runs the mixed input gemm, and applies a scaling factor and adds a zero-point to B before mma (mode 2). Uses a group size of 128. + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=128 --mode=2 */ #include @@ -79,20 +101,28 @@ #include "cutlass/util/reference/host/gett.hpp" #include "helper.h" -#include "unfused_weight_dequantize.h" +#include "unfused_weight_dequantize.hpp" using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +// This is just an example, so we use a regular enum so we can compare directly to the command-line int. +enum GemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// -using MmaType = cutlass::half_t; -using QuantType = int8_t; +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; // A matrix configuration -using ElementA = MmaType; // Element type for A matrix operand +using ElementA = MmaType; // Element type for A matrix operand using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) @@ -101,6 +131,14 @@ using ElementB = QuantType; // E using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementZero = cutlass::half_t; +using ElementScale = cutlass::half_t; +using LayoutScale = cutlass::layout::RowMajor; + // C/D matrix configuration using ElementC = cutlass::half_t; // Element type for C and D matrix operands using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands @@ -109,51 +147,107 @@ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // M // D matrix configuration using ElementD = ElementC; using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Core kernel configurations using ElementAccumulator = float; // Element type for internal accumulation using ElementCompute = float; // Element type for epilogue computation using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_128,_256,_64>; // Threadblock-level tile size +using TileShape = Shape<_128,_256,cute::Int>; // Threadblock-level tile size using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster -using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size -using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder - - -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, + EpilogueTileType, ElementAccumulator, ElementAccumulator, - // Lie here about layout of C and D since we do swap and transpose trick + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, - ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, - cutlass::epilogue::NoSmemWarpSpecialized // This is the only epi supporting the required swap + transpose. + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. >::CollectiveOp; -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< +// ============================================================ MIXED INPUT NO SCALES ============================================================================ +// The collective will infer that the narrow type should be upcasted to the wide type. +// We swap A and B operands to the builder here +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< Shape, // Indicates ProblemShape - CollectiveMainloop, + CollectiveMainloopConvertOnly, CollectiveEpilogue >; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; -using StrideA = typename Gemm::GemmKernel::StrideA; -using StrideB = typename Gemm::GemmKernel::StrideB; -using StrideC = typename Gemm::GemmKernel::StrideC; -using StrideD = typename Gemm::GemmKernel::StrideD; +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ================================================================== +// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format. +using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleWithZeroPoint, + CollectiveEpilogue +>; + +using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter; +// ================================================================================================================================================================= + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; +using StrideC = typename GemmKernelScaleWithZeroPoint::StrideC; +using StrideD = typename GemmKernelScaleWithZeroPoint::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; // // Data members @@ -163,20 +257,25 @@ using StrideD = typename Gemm::GemmKernel::StrideD; StrideA stride_A; StrideB stride_B; StrideC stride_C; +StrideC_ref stride_C_ref; StrideD stride_D; +StrideD_ref stride_D_ref; uint64_t seed; -// Initialization functions don't handle sub-byte types so we use uint8 to initialize and a separate -// kernel to pack the data if it is necessary. -using InitializationType = cute::conditional_t < 8, uint8_t, QuantType>; +// Scale and Zero share a stride since the layout and shapes must be the same. +using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; -cutlass::HostTensor tensor_A; -cutlass::HostTensor tensor_B_init; -cutlass::HostTensor tensor_B; -cutlass::HostTensor tensor_B_dq; -cutlass::HostTensor tensor_C; -cutlass::HostTensor tensor_D; -cutlass::HostTensor tensor_ref_D; +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_B_dq; +cutlass::HostTensor tensor_scale; +cutlass::HostTensor tensor_zero; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -192,7 +291,9 @@ struct Options { float alpha = 1.0f; float beta = 0.0f; int iterations = 1000; + int mode = 2; int m = 5120, n = 4096, k = 4096; + int g = 128; int l = 1; // Parses the command line @@ -208,6 +309,8 @@ struct Options { cmd.get_cmd_line_argument("n", n); cmd.get_cmd_line_argument("k", k); cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("mode", mode); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); @@ -224,13 +327,15 @@ struct Options { << " --n= Sets the N extent of the GEMM\n" << " --k= Sets the K extent of the GEMM\n" << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" << " --iterations= Number of profiling iterations to perform.\n\n"; out << "\n\nExamples:\n\n" - << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n"; + << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; return out; } @@ -289,114 +394,146 @@ bool initialize_tensor( scope_min = -8; } cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); + view, seed, scope_max, scope_min); return true; } -template +template bool initialize_quant_tensor( cutlass::TensorView view, uint64_t seed=2023) { - Element scope_max, scope_min; - constexpr int bits_input = cute::sizeof_bits_v; - static_assert(bits_input <= 8, "Quantization type can be at most 8 bits"); - - if constexpr (bits_input == 8) { - // Directly init 1-byte types - static_assert(cute::is_same_v, "Init type should equal quant type for 1 byte types"); - scope_max = std::numeric_limits::max(); - scope_min = std::numeric_limits::min(); - } else { - static_assert(cute::is_same_v, "Init type should be uint8_t for sub-byte types"); - scope_max = (1 << bits_input); - scope_min = 0; - } + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); cutlass::reference::host::TensorFillRandomUniform( - view, seed, scope_max, scope_min, 0); + view, seed, scope_max, scope_min); return true; } template -bool initialize_with_one( - cutlass::TensorView view) { - cutlass::reference::host::TensorFill(view, Element(1.0f)); +bool initialize_scale( + cutlass::TensorView view, + const Options &options) { + + if (options.mode == GemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. + cutlass::reference::host::TensorFill(view, Element(1.0f)); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min); + } return true; } -template -void prepare_packed_data(cutlass::HostTensor view_dst_data, - cutlass::HostTensor view_src_data, - const L& cute_layout) { - if constexpr (cute::is_same_v) { - view_dst_data.copy_in_device_to_device(view_src_data.device_data()); - } - else { - pack_data(view_dst_data.device_data(), view_src_data.device_data(), cute_layout); +template +bool initialize_zero( + cutlass::TensorView view, + const Options &options) { + + if (options.mode == GemmMode::ScaleWithZeroPoint) { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2.0f, -2.0f); + } else { + // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. + cutlass::reference::host::TensorFill(view, Element(0.0f)); } + return true; } /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { auto shape_b = cute::make_shape(options.n, options.k, options.l); + const int scale_k = (options.k + options.g - 1) / options.g; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + // Reverse stride here due to swap and transpose stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); tensor_A.resize(a_coord); - tensor_B_init.resize(b_coord); tensor_B.resize(b_coord); tensor_B_dq.resize(b_coord); tensor_C.resize(c_coord); tensor_D.resize(c_coord); tensor_ref_D.resize(c_coord); - // We need scales since the "dequantize" kernels expects them. We just set them to 1 so the values get converted - // to the mma type. - cutlass::HostTensor tensor_scale; - tensor_scale.resize({1 * options.l, options.n}); + tensor_scale.resize({scale_k * options.l, options.n}); + tensor_zero.resize({scale_k * options.l, options.n}); initialize_tensor(tensor_A.host_view(), seed + 2022); - initialize_quant_tensor(tensor_B_init.host_view(), seed + 2021); + initialize_quant_tensor(tensor_B.host_view(), seed + 2021); initialize_tensor(tensor_C.host_view(), seed + 2020); - initialize_with_one(tensor_scale.host_view()); + initialize_scale(tensor_scale.host_view(), options); + initialize_zero(tensor_zero.host_view(), options); tensor_A.sync_device(); - tensor_B_init.sync_device(); + tensor_B.sync_device(); tensor_C.sync_device(); tensor_scale.sync_device(); + tensor_zero.sync_device(); auto layout_B = make_layout(shape_b, stride_B); - prepare_packed_data(tensor_B, tensor_B_init, layout_B); - auto shape_scale = cute::make_shape(options.n, 1, options.l); - auto layout_scale = make_layout(shape_scale); - dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), layout_scale); + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); - tensor_B.sync_host(); + dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), tensor_zero.device_data(), layout_scale_zero, options.g); tensor_B_dq.sync_host(); } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +template +Args args_from_options(const Options &options) { - typename Gemm::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - {options.m, options.n, options.k, options.l}, - {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, - {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} - }; - - return arguments; +// Swap the A and B tensors, as well as problem shapes here. + if (options.mode == GemmMode::ConvertOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A}, + {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + }; + } + else if (options.mode == GemmMode::ScaleOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g}, + {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + }; + } + else if (options.mode == GemmMode::ScaleWithZeroPoint) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {tensor_B.device_data(), stride_B, tensor_A.device_data(), stride_A, tensor_scale.device_data(), stride_S, options.g, tensor_zero.device_data()}, + {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + }; + } else { + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + exit(-1); + } } bool verify(const Options &options) { @@ -404,44 +541,64 @@ bool verify(const Options &options) { // Compute reference output // - // Create instantiation for device reference gemm kernel - auto A = cute::make_tensor(tensor_A.host_data(), - cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); - auto B = cute::make_tensor(tensor_B_dq.host_data(), - cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); - auto C = cute::make_tensor(tensor_C.host_data(), - cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); - auto D = cute::make_tensor(tensor_ref_D.host_data(), - cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + // In this example, we use the GPU default kernels as a reference (unfused scale) + // This is to avoid numerical differences from different accumulation order. - using unused_t = decltype(D); + // Again, due to numerical differences, we must use fast acc here when the mma type is + // FP8 as the fused implementation only supports fast acc at the moment. + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; - - cutlass::reference::host::GettEpilogueParams< - typename Gemm::EpilogueOutputOp::ElementScalar, - typename Gemm::EpilogueOutputOp::ElementScalar, + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, ElementAccumulator, - ElementCompute, - decltype(C), - decltype(D), - unused_t, // bias - unused_t, // aux - unused_t, // valpha - unused_t // vbeta - > epilogue_params; + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; - epilogue_params.C = C; - epilogue_params.D = D; - epilogue_params.alpha = options.alpha; - epilogue_params.beta = options.beta; + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; - // get reference result - cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B_dq.device_data(), stride_B}, + {{options.alpha, options.beta}, tensor_C.device_data(), stride_C_ref, tensor_ref_D.device_data(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); // compare_reference tensor_D.sync_host(); - bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + tensor_ref_D.sync_host(); + const ElementD epsilon(1e-2f); + const ElementD non_zero_floor(1e-4f); + bool passed = cutlass::reference::host::TensorRelativelyEquals(tensor_ref_D.host_view(), tensor_D.host_view(), epsilon, non_zero_floor); return passed; } @@ -455,7 +612,7 @@ int run(Options &options) Gemm gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options); + auto arguments = args_from_options(options); // Using the arguments, query for extra workspace required for matrix multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -549,7 +706,26 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - run(options); + if (options.mode == GemmMode::ConvertOnly) { + std::cout << "Running in no scale mode." << std::endl; + run(options); + } + else if (options.mode == GemmMode::ScaleOnly) { + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + run(options); + } + else if (options.mode == GemmMode::ScaleWithZeroPoint) { + if (options.g == options.k) { + std::cout << "Running in per-column scale and zero mode." << std::endl; + } else { + std::cout << "Running in group scale and zero mode." << std::endl; + } + run(options); + } #endif return 0; diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt index 8b8ac5ba..7fc05336 100644 --- a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -27,9 +27,33 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_DIRECT_BATCHED --m=2048 --n=2048 --k=2048 --l=2 --mode=0 --iterations=0) # Direct conversion + +set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 --iterations=0) # Per Column scaling +set(TEST_SCALE_ZERO_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=2 --iterations=0) # Per Column scaling + +set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --g=512 --mode=1 --iterations=0) # Group-wise scaling +set(TEST_SCALE_ZERO_GROUPED --m=2048 --n=5120 --k=8192 --g=256 --mode=2 --iterations=0) # Group-wise scaling with zero-point + +set(TEST_SCALE_RESIDUE --m=128 --n=128 --k=320 --g=128 --mode=1 --iterations=0) # Final group has residue +set(TEST_SCALE_ZERO_RESIDUE --m=128 --n=128 --k=192 --g=128 --mode=2 --iterations=0) # Final group has residue + +set(TEST_ALPHA_BETA --alpha=0.5 --beta=0.7 --mode=2 --iterations=0) # Alpha and Beta with default shapes cutlass_example_add_executable( 55_hopper_mixed_dtype_gemm 55_hopper_mixed_dtype_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_ZERO_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_ZERO_GROUPED + TEST_SCALE_RESIDUE + TEST_SCALE_ZERO_RESIDUE + TEST_ALPHA_BETA ) diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md index 2c023651..8c393a6b 100644 --- a/examples/55_hopper_mixed_dtype_gemm/README.md +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -9,17 +9,14 @@ This first version only supports mixed type GEMMs using TMA. ## Performance -While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x int8` for problems that are compute bound. +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x {int8, int4}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16, bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type. We are currently optimizing the following cases: 1. Memory bound cases for all types -1. Compute bound cases for `{16-bit, 8-bit} x {4-bit, 2-bit}` - -As a result, we do not suggest using this example as a benchmarking reference until all of our optimizations are complete (this will be clearly stated in this README in a future release). ## Limitations -* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen to control the layout of the epilogue as shown in the example. Note that TMA epilogues currently do not support swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release. +* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen. Note that TMA epilogues currently do not support *implicit* swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release. * The layout of the narrow type must be K-major. This means the following: * Narrow type is the A operand: Must be Row-Major @@ -29,8 +26,12 @@ As a result, we do not suggest using this example as a benchmarking reference un * TMA requires an alignment of 128 bits. As a result, for a type with `B` bits, `B x TILE_K` must be a multiple of 128 bits. +* The type of the scale and zero-point type must be two bytes or more. + +* The group size must be equal to gemm-k size (indicating a broadcast), or it must be a multiple of the threadblock-k size. + ## Upcoming features -* Support for applying scales after conversion, but before issuing tensor core math (input scale fusion) is planned for v3.4. +* Optimizations for memory bound cases. -* Many optimizations for SOL performance. +* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. diff --git a/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp similarity index 53% rename from examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h rename to examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp index e19a01dc..106e9897 100644 --- a/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h +++ b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.hpp @@ -9,13 +9,15 @@ template __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, - const QuantizedElement* q_buffer, - const OperandLayout operand_layout, - const ElementScale* scale_buffer, - const ScaleBroadCastLayout broadcasted_scale_layout, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, ThrLayout thr_layout) { using namespace cute; @@ -33,6 +35,7 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting // It is expected that K % G == 0 Tensor gmem_scale_broadcasted = make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + Tensor gmem_zero_broadcasted = make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); // Assign 1 thread per element in the thread block auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); // @@ -41,16 +44,21 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, // Tile across the block auto gOp_dq = local_tile(gmem_op_dq, blk_shape, blk_coord); auto gScale = local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); auto gOp_q = local_tile(gmem_op_q, blk_shape, blk_coord); auto tOpDq_gOpDq = local_partition(gOp_dq, thr_layout, threadIdx.x); auto tScale_gScale = local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = local_partition(gZero, thr_layout, threadIdx.x); auto tOpQ_gOpQ = local_partition(gOp_q, thr_layout, threadIdx.x); // Make a fragment of registers to hold gmem loads Tensor rmem_op_q = make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); Tensor rmem_scale = make_fragment_like(tScale_gScale(_, _, _, 0)); + Tensor rmem_zero = make_fragment_like(tZero_gZero(_, _, _, 0)); Tensor rmem_op_dq = make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + Tensor rmem_op_scaled = make_fragment_like(rmem_op_dq); + Tensor rmem_zero_buf = make_fragment_like(rmem_zero); Tensor pred_id = make_identity_tensor(shape(operand_layout)); auto pred_blk_tile = local_tile(pred_id, blk_shape, blk_coord); @@ -63,8 +71,12 @@ __global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, if (thread_offset < size<0>(operand_layout)) { copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); copy(tScale_gScale(_, _, _, ii), rmem_scale); - transform(rmem_op_q, rmem_op_dq, [] (const QuantizedElement& elt) { return DequantizedElement(elt); } ); - transform(rmem_op_dq, rmem_scale, rmem_op_dq, multiplies{}); + copy(tZero_gZero(_, _, _, ii), rmem_zero); + transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{}); + transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{}); + transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); } } @@ -74,12 +86,15 @@ template void dequantize_weight(DequantizedElement* dq_buffer, - const QuantizedElement* q_buffer, - const OperandLayout operand_layout, - const ElementScale* scale_buffer, - const ScaleLayout scale_layout) { + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size) { using namespace cute; @@ -87,9 +102,9 @@ void dequantize_weight(DequantizedElement* dq_buffer, auto thr_layout = make_layout(make_shape(Int{})); const auto num_rows = get<0>(shape(operand_layout)); - const auto num_cols = get<1>(shape(operand_layout)); // [MN, K, L] + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] - const auto num_cols_scale = get<1>(shape(scale_layout)); // [MN, G, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] if (num_rows != size<0>(scale_layout)) { std::cerr << "Invalid first dimension for scales. Must match first dim for weights." @@ -98,81 +113,18 @@ void dequantize_weight(DequantizedElement* dq_buffer, exit(-1); } - if (num_cols % num_cols_scale != 0) { - std::cerr << "Invalid shape for weight / scales. Weight cols must be a multiple of scale cols." - << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) - << std::endl; - exit(-1); - } - const auto scale_stride0 = get<0>(stride(scale_layout)); const auto scale_stride1 = get<1>(stride(scale_layout)); const auto scale_stride2 = get<2>(stride(scale_layout)); - auto scale_shape_bcast = make_shape(num_rows, make_shape(num_cols / num_cols_scale, num_cols_scale), batches); + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); - const auto blocks_x = num_cols; + const auto blocks_x = gemm_k; const auto blocks_y = batches; dim3 blocks(blocks_x, blocks_y, 1); - dequantize_weight_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, scale_layout_bcast, thr_layout); - CUDA_CHECK(cudaDeviceSynchronize()); -} - - -template -__global__ void pack_data_kernel(SubbyteType* packed_data_ptr, - const uint8_t* unpacked_data_ptr, - const size_t max_elts) { - using namespace cute; - - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - uint8_t data[ELTS_PER_THREAD]; - if (tid < max_elts) { - const uint8_t* read_ptr = unpacked_data_ptr + tid * ELTS_PER_THREAD; - for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) { - data[ii] = read_ptr[ii]; - } - - - using WriteType = cute::array_subbyte; - WriteType* write_ptr = reinterpret_cast(packed_data_ptr); - - WriteType packed_data; - for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) { - SubbyteType elt(data[ii]); - packed_data[ii] = elt; - } - write_ptr[tid] = packed_data; - } - -} - -template -void pack_data(SubbyteType* packed_data, const uint8_t* unpacked_data, const OperandLayout operand_layout) { - static_assert(cute::sizeof_bits_v < 8, "First operand must be a sub-byte type"); - constexpr int packed_elements = 8 / cute::sizeof_bits_v; - - if (cute::stride<0>(operand_layout) == 1 && (cute::shape<0>(operand_layout) % packed_elements)) { - std::cerr << "Invalid shape / stride for dimension 0. Contiguous dimension must be a multiple of " - << packed_elements << std::endl; - exit(-1); - } - - if (cute::stride<1>(operand_layout) == 1 && (cute::shape<1>(operand_layout) % packed_elements)) { - std::cerr << "Invalid shape / stride for dimension 1. Contiguous dimension must be a multiple of " - << packed_elements << std::endl; - exit(-1); - } - - const int64_t total_threads = cute::size(operand_layout) / packed_elements; - - const int threads_per_block = 256; - const int64_t num_blocks = (total_threads + threads_per_block - 1) / threads_per_block; - pack_data_kernel<<>>(packed_data, unpacked_data, total_threads); + dequantize_weight_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); CUDA_CHECK(cudaDeviceSynchronize()); } diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu new file mode 100644 index 00000000..2ef3864f --- /dev/null +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -0,0 +1,520 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + + This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA + warp-specialized cooperative kernel. + The new feature showcased in this example is on-the-fly modification of TMA descriptors + to move between batches (represented by l). + + To run this example: + + $ ./examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm --m=2048 --n=2048 --k=2048 --l=10 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelArrayTmaWarpSpecializedCooperative; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedArray; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 1024, n = 512, k = 1024, l = 10; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "56_hopper_ptr_array_batched_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the batch count for Ptr-Array GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "56_hopper_ptr_array_batched_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.l; ++i) { + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = options.m * options.k; + int64_t elements_B = options.k * options.n; + int64_t elements_C = options.m * options.n; + int64_t elements_D = options.m * options.n; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.l); + std::vector ptr_B_host(options.l); + std::vector ptr_C_host(options.l); + std::vector ptr_D_host(options.l); + + for (int32_t i = 0; i < options.l; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(options.l); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.l); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.l); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.l); + ptr_D.copy_from_host(ptr_D_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kArray, + {{options.m, options.n, options.k, options.l}}, + {ptr_A.get(), stride_A, ptr_B.get(), stride_B}, + {{options.alpha, options.beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D}, + hw_info + }; + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.l; ++i) { + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), options.m * options.n); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Batches : " << options.l << std::endl; + std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl; + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt new file mode 100644 index 00000000..5fa5e62b --- /dev/null +++ b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt @@ -0,0 +1,52 @@ + +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=0) # Square problem sizes +set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=0) # Square problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Default problem sizes +set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=0) # Default problem sizes + +set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=0) # Small-k problem sizes +set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=0) # Small-k problem sizes + +cutlass_example_add_executable( + 56_hopper_ptr_array_batched_gemm + 56_hopper_ptr_array_batched_gemm.cu + TEST_COMMAND_OPTIONS + TEST_SQUARE + TEST_SQUARE_LARGE_BATCH + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_BATCH + TEST_SMALLK + TEST_SMALLK_LARGE_BATCH + ) diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu new file mode 100644 index 00000000..565a0813 --- /dev/null +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -0,0 +1,677 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + + This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA + warp-specialized cooperative kernel. + For this example all scheduling work is performed on the device. + The new feature showcased in this example is on-the-fly modification of TMA descriptors + to move between groups/problem_count (represented by groups). + + To run this example: + + $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + + To run this example for a set of problems using the benchmark option: + + $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using ElementC = float; // Element type for C and D matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_256,_128,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum; // Kernel to launch +using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecializedGroup; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + + if (m < 1) { + m = ((rand() % 512) + 1); + } + + if (n < 1) { + n = ((rand() % 512) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "57_hopper_grouped_gemm\n\n" + << " Hopper FP8 Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size.\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "57_hopper_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += cute::size(problem); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{}))); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{}))); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{}))); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{}))); + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {{options.alpha, options.beta}, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {M, N, K}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + #if 0 + std::cout << "Group: " << i << " Status: " << passed << std::endl; + #endif + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Problem Sizes: " << std::endl; + for (auto const & problem : options.problem_sizes_host) { + std::cout << " " << problem << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl; + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000..d9d281ba --- /dev/null +++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt @@ -0,0 +1,56 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.7 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes + +cutlass_example_add_executable( + 57_hopper_grouped_gemm + 57_hopper_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + ) diff --git a/examples/60_cutlass_import/CMakeLists.txt b/examples/60_cutlass_import/CMakeLists.txt index 8cf99195..381c75b2 100644 --- a/examples/60_cutlass_import/CMakeLists.txt +++ b/examples/60_cutlass_import/CMakeLists.txt @@ -37,9 +37,10 @@ cmake_minimum_required(VERSION 3.18 FATAL_ERROR) -project(cutlass_import_example VERSION 0.1 LANGUAGES CXX CUDA) +project(cutlass_import_example VERSION 0.2 LANGUAGES CXX CUDA) -if (DEFINED CUTLASS_DIR) +if (CUTLASS_DIR) + message(STATUS "Using CUTLASS specified at ${CUTLASS_DIR}.") list(APPEND CMAKE_PREFIX_PATH ${CUTLASS_DIR}) endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 47445ca2..ebcaef18 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -44,7 +44,7 @@ function(cutlass_example_add_executable NAME) set(__DISABLE_TESTS OFF) endif() - cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF) add_dependencies(cutlass_examples ${NAME}) @@ -136,6 +136,8 @@ foreach(EXAMPLE 53_hopper_gemm_permute 54_hopper_fp8_warp_specialized_gemm 55_hopper_mixed_dtype_gemm + 56_hopper_ptr_array_batched_gemm + 57_hopper_grouped_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index d27035fd..bf95576a 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -27,7 +27,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + cutlass_example_add_executable( sgemm_nt_1 sgemm_nt_1.cu ) + +cutlass_example_add_executable( + tiled_copy + tiled_copy.cu +) + diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu new file mode 100644 index 00000000..403be25c --- /dev/null +++ b/examples/cute/tutorial/tiled_copy.cu @@ -0,0 +1,260 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +// This is a simple tutorial showing several ways to partition a tensor into tiles then +// perform efficient, coalesced copies. This example also shows how to vectorize accesses +// which may be a useful optimization or required for certain workloads. +// +// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with +// dimensions (m, n) have been partitioned via `tiled_divide()`. +// +// The result are a part of compatible tensors with dimensions ((M, N), m', n'), where +// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles +// within the tensor. +// +// Each statically sized tile is mapped to a CUDA threadblock which performs efficient +// loads and stores to Global Memory. +// +// `copy_kernel()` uses `cute::local_partition()` to partition the tensor and map +// the result to threads using a striped indexing scheme. Threads themselve are arranged +// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile. +// +// `copy_kernel_vectorized()` uses `cute::make_tiled_copy()` to perform a similar +// partitioning using `cute::Copy_Atom` to perform vectorization. The actual vector +// size is defined by `ThreadShape`. +// +// This example assumes the overall tensor shape is divisible by the tile size and +// does not perform predication. + + +/// Simple copy kernel. +// +// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N). +template +__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) +{ + using namespace cute; + + // Slice the tiled tensors + Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + + // Construct a partitioning of the tile among threads with the given thread arrangement. + + // Concept: Tensor Layout Index + Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x); + Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x); + + // Construct a register-backed Tensor with the same shape as each thread's partition + auto fragment = make_fragment_like(thr_tile_S); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(thr_tile_S, fragment); + copy(fragment, thr_tile_D); +} + +/// Vectorized copy kernel. +/// +/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation +/// has the precondition that pointers are aligned to the vector size. +/// +template +__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout) +{ + using namespace cute; + using Element = typename TensorS::value_type; + + // Slice the tensors to obtain a view into each tile. + Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + + // Define `AccessType` which controls the size of the actual memory access. + using AccessType = cutlass::AlignedArray; + + // A copy atom corresponds to one hardware memory access. + using Atom = Copy_Atom, Element>; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous data + // in GMEM. Alternative thread layouts are possible but may result in uncoalesced + // reads. Alternative vector layouts are also possible, though incompatible layouts + // will result in compile time errors. + auto tiled_copy = + make_tiled_copy( + Atom{}, // access size + ThreadLayout{}, // thread layout + VecLayout{}); // vector layout (e.g. 4x1) + + // Construct a Tensor corresponding to each thread's slice. + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + Tensor thr_tile_S = thr_copy.partition_S(tile_S); + Tensor thr_tile_D = thr_copy.partition_D(tile_D); + + // Construct a register-backed Tensor with the same shape as each thread's partition + auto fragment = make_fragment_like(thr_tile_D); + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(tiled_copy, thr_tile_S, fragment); + copy(tiled_copy, fragment, thr_tile_D); +} + +/// Helper to convert a shape to a dim3 +template +dim3 shape_to_dim3(Shape shape) +{ + using namespace cute; + + CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{}); + auto result = append<3>(product_each(shape), 1u); + + return dim3(get<0>(result), get<1>(result), get<2>(result)); +} + +/// Main function +int main(int argc, char** argv) +{ + // + // Given a 2D shape, perform an efficient copy + // + + using namespace cute; + using Element = float; + + // Define a tensor shape with dynamic extents (m, n) + auto tensor_shape = make_shape(256, 512); + + thrust::host_vector h_S(size(tensor_shape)); + thrust::host_vector h_D(size(tensor_shape)); + + // + // Initialize + // + + for (size_t i = 0; i < h_S.size(); ++i) { + h_S[i] = static_cast(i); + h_D[i] = Element{}; + } + + thrust::device_vector d_S = h_S; + thrust::device_vector d_D = h_D; + + // + // Make tensors + // + + Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape)); + Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape)); + + // + // Partition + // + + + // Define a statically sized block (M, N). + // + // Note, by convention, capital letters are used to represent static modes. + auto block_shape = make_shape(Int<128>{}, Int<64>{}); + + if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) { + std::cerr << "The tensor shape must be divisible by the block shape." << std::endl; + return -1; + } + + // Tile the tensor (m, m) ==> ((M, N), m', n') where (M, N) is the static tile + // shape, and modes (m', n') correspond to the number of tiles. + // + // These will be used to determine the CUDA kernel grid dimensinos. + Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); + Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); + + // Thread arrangement + Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{})); + + // Vector dimensions + Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); + + // + // Determine grid and block dimensions + // + + dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D))); // Grid shape corresponds to modes m' and n' + dim3 blockDim(size(shape(thr_layout))); + + // + // Launch the kernel + // + copy_kernel_vectorized<<< gridDim, blockDim >>>( + tiled_tensor_S, + tiled_tensor_D, + thr_layout, + vec_layout); + + cudaError result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl; + return -1; + } + + // + // Verify + // + + h_D = d_D; + + int32_t errors = 0; + int32_t const kErrorLimit = 10; + + for (size_t i = 0; i < h_D.size(); ++i) { + if (h_S[i] != h_D[i]) { + std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i << "]: " << h_D[i] << std::endl; + + if (++errors >= kErrorLimit) { + std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl; + return -1; + } + } + } + + std::cout << "Success." << std::endl; + + return 0; +} + diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb index afd6dab4..428d28f0 100644 --- a/examples/python/00_basic_gemm.ipynb +++ b/examples/python/00_basic_gemm.ipynb @@ -357,7 +357,7 @@ "## Handling errors\n", "The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n", "\n", - "Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user." + "Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error." ] }, { @@ -371,6 +371,75 @@ "# td.stages = 8\n", "# plan.compile(td)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specializations for other data types\n", + "\n", + "Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n", + "\n", + "For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cutlass.backend.utils.device import device_cc\n", + "\n", + "# 3xTF32 requires SM80 or higher\n", + "if device_cc() >= 80:\n", + " plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n", + " plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n", + "\n", + " # Create input/output tensors in FP32\n", + " A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n", + " C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n", + "\n", + " # Run the GEMM\n", + " plan.run(A, B, C, D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import torch\n", + "except ImportError:\n", + " print(\"PyTorch is not available. Skipping FP8 example\")\n", + " import sys; sys.exit(0)\n", + "\n", + "if not hasattr(torch, \"float8_e4m3fn\"):\n", + " print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n", + " import sys; sys.exit(0)\n", + "\n", + "# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n", + "if device_cc() >= 90:\n", + " plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n", + " layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n", + " layout_C=cutlass.LayoutType.ColumnMajor)\n", + "\n", + " # Create input/output tensors in FP8\n", + " A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", + " C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", + "\n", + " # Run the GEMM\n", + " plan.run(A, B, C, D, print_module=print_module)" + ] } ], "metadata": { diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb index 9196af13..b811c5e3 100644 --- a/examples/python/02_pytorch_extension_grouped_gemm.ipynb +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ -134,7 +134,7 @@ "id": "590a3bc5", "metadata": {}, "source": [ - "We'll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch." + "We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch." ] }, { @@ -144,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "As, Bs, Cs, Ds, = generate_problems(50)\n", + "As, Bs, Cs, Ds, = generate_problems(20)\n", "\n", "plan.run(As, Bs, Cs, Ds, print_module=True)\n", "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp index 40b9e2c2..57c4af75 100644 --- a/include/cute/arch/cluster_sm90.hpp +++ b/include/cute/arch/cluster_sm90.hpp @@ -95,6 +95,7 @@ CUTE_DEVICE dim3 cluster_grid_dims() return gridDim; #elif defined(_MSC_VER) CUTE_RUNTIME_ASSERT("cluster_grid_dims() can only be called on device"); + return {0, 0, 0}; #else return {0, 0, 0}; #endif @@ -114,6 +115,7 @@ CUTE_DEVICE dim3 cluster_id_in_grid() return blockIdx; #elif defined(_MSC_VER) CUTE_RUNTIME_ASSERT("cluster_id_in_grid() can only be called on device"); + return {0, 0, 0}; #else return {0, 0, 0}; #endif diff --git a/include/cute/arch/copy_sm90.hpp b/include/cute/arch/copy_sm90.hpp index 6ac96438..ab15eb6d 100644 --- a/include/cute/arch/copy_sm90.hpp +++ b/include/cute/arch/copy_sm90.hpp @@ -40,6 +40,11 @@ # define CUTE_ARCH_TMA_SM90_ENABLED #endif +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) && \ + ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + namespace cute { diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index d33ed305..0b6d40e3 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -42,6 +42,7 @@ #include #include +#include #include // to_Format<[u]intX> #include // to_Format @@ -200,6 +201,141 @@ prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) #endif } +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a TensorMap modification (by each field) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Replace tensor pointer directly in GMEM +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" + :: "l"(gmem_int_desc), "l"(new_desc_addr)); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" + :: "l"(smem_int64_desc), "l"(new_desc_addr)); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc, + cute::array const& prob_shape, + cute::array const& prob_stride) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + asm volatile ( + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" + :: "l"(gmem_int_desc), "r"(smem_int_desc)); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a release fence operation (needed when modifying tensormap directly in GMEM) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_release() +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a acquire fence operation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" + : + : "l"(gmem_int_desc) + : "memory"); + asm volatile ( + "cvta.global.u64 %0, %0;" + : + : "l"(gmem_int_desc), "l"(gmem_int_desc) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Using TMA Descriptor modification without CUTE_ARCH_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + /////////////////////////////////////////////////////////////////////////////// } // end namespace cute diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index fcb9189f..5a362e66 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -775,6 +775,104 @@ struct SM90_TMA_STORE } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_n) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n) + : "memory"); +#else + CUTE_RUNTIME_ASSERT("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n); + } +}; + // Indicate arrival of warp issuing TMA_STORE CUTE_HOST_DEVICE static void tma_store_arrive() { diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp index 8dc5fdcb..9488a843 100644 --- a/include/cute/arch/mma_sm80.hpp +++ b/include/cute/arch/mma_sm80.hpp @@ -1,4 +1,4 @@ - /************************************************************************************************** +/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index de320a22..76c48c2b 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -31,26 +31,28 @@ #pragma once #include + #include -#include - #include +#include #include +#include + namespace cute { template struct Copy_Atom; -template -struct Copy_Atom : Copy_Atom, T> +template +struct Copy_Atom : Copy_Atom, CopyInternalType> {}; -template -struct Copy_Atom, T> +template +struct Copy_Atom, CopyInternalType> : Copy_Traits { using Traits = Copy_Traits; @@ -61,7 +63,7 @@ struct Copy_Atom, T> using BitLayoutDst = typename Traits::DstLayout; using BitLayoutRef = typename Traits::RefLayout; - using ValType = T; + using ValType = CopyInternalType; using ValLayoutSrc = decltype(upcast::value>(BitLayoutSrc{})); using ValLayoutDst = decltype(upcast::value>(BitLayoutDst{})); @@ -80,7 +82,7 @@ struct Copy_Atom, T> auto with(TraitsArgs&&... args) const { auto traits = Traits::with(std::forward(args)...); - return Copy_Atom{traits}; + return Copy_Atom{traits}; } // @@ -88,19 +90,19 @@ struct Copy_Atom, T> // // Check and call instruction, or recurse - template + template CUTE_HOST_DEVICE void - call(Tensor const& src, - Tensor & dst) const + call(Tensor const& src, + Tensor & dst) const { static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); if constexpr (is_constant::value || is_constant::value) { - // Dispatch to unpack for instruction + // Dispatch to unpack to execute instruction return copy_unpack(*this, src, dst); } else if constexpr (is_tuple::value && @@ -110,7 +112,7 @@ struct Copy_Atom, T> // ((A,B,C,...)) -> (A,B,C,...) return copy(*this, tensor<0>(src), tensor<0>(dst)); } else { - static_assert(sizeof(TS) < 0, "No instruction match and no recursion possible."); + static_assert(dependent_false, "No instruction match and no recursion possible."); } } @@ -135,7 +137,7 @@ struct ThrCopy; template coord [Need not be 2D...] - class ShapeTile_MN> // coord space + class ShapeTiler_MN> // coord space struct TiledCopy : Copy_Atom { // Layout information from the CopyAtom @@ -148,8 +150,7 @@ struct TiledCopy : Copy_Atom using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); // Layout information for the TiledCopy - using Tiler_MN = ShapeTile_MN; - using TiledShape_MN = decltype(shape(ShapeTile_MN{})); + using Tiler_MN = ShapeTiler_MN; using TiledLayout_TV = LayoutCopy_TV; using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); @@ -172,12 +173,9 @@ struct TiledCopy : Copy_Atom auto tidfrg_S(STensor&& stensor) { - constexpr int R = remove_cvref_t::rank; - static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); - // Generalize the dimension checks for arbitrary rank - //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); } @@ -196,17 +194,14 @@ struct TiledCopy : Copy_Atom auto tidfrg_D(DTensor&& dtensor) { - constexpr int R = remove_cvref_t::rank; - static_assert(R >= rank_v, "Rank of tensor to be partitioned too small."); - // Generalize the dimension checks for arbitrary rank - //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); } // Tile a tensor or a layout from shape - // (Tile,(RestM,RestN,...)) + // ((TileM,TileN,...), (RestM,RestN,...)) // to shape // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) template @@ -232,7 +227,7 @@ struct TiledCopy : Copy_Atom // Transform the tile mode auto tv_tensor = tensor.compose(thrval2mn, _); - // ((thrid,val),(RM,RN,...)) + // ((thrid,val),(RestM,RestN,...)) // Unfold and return return tv_tensor(make_coord(_,_), _); @@ -253,7 +248,7 @@ struct TiledCopy : Copy_Atom auto V = size<0>(tensor); - auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{})); + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); @@ -278,7 +273,7 @@ struct TiledCopy : Copy_Atom get_layoutS_TV() { // (M,N) -> (M,N) - auto ref_S = make_layout(make_shape(TiledShape_MN{}, Int<1>{})); + auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); // (thr_idx,val_idx) -> (M,N) return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); } @@ -290,7 +285,7 @@ struct TiledCopy : Copy_Atom // (thr_idx,val_idx) -> (M,N) auto layoutS_TV = get_layoutS_TV(); // (M,K) -> (thr_idx,val_idx) - auto layoutS_MK = right_inverse(layoutS_TV).with_shape(TiledShape_MN{}); + auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{})); // athrid = (v,m,k) -> thr_idx auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); @@ -303,7 +298,7 @@ struct TiledCopy : Copy_Atom get_layoutD_TV() { // (M,N) -> (M,N) - auto ref_D = make_layout(make_shape(TiledShape_MN{}, Int<1>{})); + auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); // (thr_idx,val_idx) -> (M,N) return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); } @@ -315,7 +310,7 @@ struct TiledCopy : Copy_Atom // (thr_idx,val_idx) -> (M,N) auto layoutD_TV = get_layoutD_TV(); // (M,K) -> (thr_idx,val_idx) - auto layoutD_MK = right_inverse(layoutD_TV).with_shape(TiledShape_MN{}); + auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{})); // athrid = (v,m,k) -> thr_idx auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); @@ -406,51 +401,44 @@ make_tiled_copy_impl(Copy_Atom const& atom, // These tile the Copy_Atom as a whole // -template +template CUTE_HOST_DEVICE auto -make_tiled_copy_A(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& mma) { - using MNK = typename TiledMMA::TiledShape_MNK; - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}),size<2>(MNK{}))); + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); } -template +template CUTE_HOST_DEVICE auto -make_tiled_copy_B(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& mma) { - using MNK = typename TiledMMA::TiledShape_MNK; - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}),size<2>(MNK{}))); + return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); } -template +template CUTE_HOST_DEVICE auto -make_tiled_copy_C(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& mma) { - using MNK = typename TiledMMA::TiledShape_MNK; - return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}),size<1>(MNK{}))); + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); } // returns the smallest tiled copy that can retile LayoutC_TV // for use with pipelined epilogues with subtiled stores -template +template CUTE_HOST_DEVICE auto -make_tiled_copy_C_atom(Copy_Atom const& copy_atom, - TiledMMA const& tiled_mma) +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& mma) { // Truncate the V-layout to just the Copy_Atom, keep the V-order - auto layoutC_TV = tiled_mma.get_layoutC_TV(); - auto copy_V = Int::NumValSrc>{}; + auto layoutC_TV = mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); @@ -458,8 +446,7 @@ make_tiled_copy_C_atom(Copy_Atom const& copy_atom, // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA - using MNK = typename TiledMMA::TiledShape_MNK; - auto mma_tiler = make_shape(size<0>(MNK{}),size<1>(MNK{})); + auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); auto tiler = transform(make_seq{}, [&](auto i) { @@ -474,8 +461,6 @@ make_tiled_copy_C_atom(Copy_Atom const& copy_atom, // (tid,vid) -> tile_coord auto layout_tv = composition(left_inverse(tile2mma), layout_TV); - - using MNK = typename TiledMMA::TiledShape_MNK; return make_tiled_copy_impl(copy_atom, layout_tv, tiler); } @@ -655,8 +640,10 @@ print(TiledCopy const& copy, char const* pad = "") template CUTE_HOST_DEVICE void -print(ThrCopy const&) +print(ThrCopy const& thr_copy) { + print("ThrCopy\n"); + print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); print(TiledCopy{}); } diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 132ba520..2ae6e1e8 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -43,10 +43,14 @@ namespace cute { -template +template struct AuxTmaParams { - using GmemStrides = GmemStrides_; + using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic GmemStrides g_stride_; + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static + static_assert(is_static::value); + using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle + static_assert(is_static::value); }; ////////////////////////////////////////////////////////////////////////////// @@ -138,13 +142,19 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD with tma_mbar CUTE_HOST_DEVICE constexpr Copy_Traits - with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const { + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { // We accept multicast_mask here to keep the API for both atoms consistent - // assert(multicast_mask == 0); - (void) multicast_mask; return {tma_desc_, tma_mbar}; } + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {*new_tma_desc, tma_mbar}; + } + // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr @@ -251,6 +261,13 @@ struct Copy_Traits return {tma_desc_, tma_load_mbar, multicast_mask}; } + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { + return {*new_tma_desc, tma_load_mbar, multicast_mask}; + } + // Generate the TMA coord tensor template CUTE_HOST_DEVICE constexpr @@ -508,51 +525,91 @@ coalesce_256(Layout const& layout) return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); } -template -CUTE_HOST_DEVICE constexpr -auto -coalesce_256(Tensor const& tensor) -{ - return make_tensor(tensor.data(), coalesce_256(tensor.layout())); -} - - -// Use a smem_inv_h to read through the GMEM tensor -// and construct a TMA Descriptor for the resulting instruction -// At the same time, construct the Tma Tensor's Stride to generate -// the TMA coordinates that the instruction consumes. -// template -CUTE_HOST_RTC + class VShape, class VStride> +CUTE_HOST_DEVICE constexpr auto -make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor - Layout const& smem_inv_h, // smem_idx to hier gmode - Swizzle const& swizzle) // Swizzle fn on smem_idx +construct_tma_gbasis(Tensor const& gtensor, // The original GMEM Tensor + Layout const& slayout, // The layout of SMEM + Layout const& cta_v_map) // smem_idx to hier gmode { + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("slayout : "); print(slayout); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); +#endif + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + + // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode + // smem idx -> gmem mode + auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("inv_smem_layout : "); print(inv_smem_layout); print("\n"); + print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n"); +#endif + + // + // TMA gtensor truncation + // + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); + + // Keep only the static-1 basis modes into gmem + auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full); + +#if 0 + print("smem_rank : "); print(smem_rank); print("\n"); + print("sidx2gmode : "); print(sidx2gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + // The smem vector is the same units as gtensor, so compose first and then recast // tma_val_idx:gmem_strides - Tensor tile_gstride = recast(gtensor.compose(smem_inv_h)); + auto tile_gstride = recast(gtensor.compose(sidx2gmode)).layout(); // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) // tma_box_shape:gmem_strides - Tensor tma_gstride = coalesce_256(tile_gstride); + auto tma_gstride = coalesce_256(tile_gstride); - // Perform the tiling to the gmem vector again, but with indirections to the gtensor modes + // Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes auto gbasis = make_identity_layout(shape(gtensor)); - auto tile_gbasis_tmp = gbasis.compose(smem_inv_h); + auto tile_gbasis_tmp = gbasis.compose(sidx2gmode); // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape // tma_box_shape:gmem_mode auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); - // Recast the original tensor for shape inspections - auto gtensor_T = recast(gtensor); + // "Coalesce" the tile basis into a compatible shape with the tma_gstride + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); // Find missing bases that don't appear in tile_gbasis - // NOTE This is essentially ArithmeticTuple complement... - // NOTE in pursuit of implementing an ArithmeticTuple logical_divide for smem_inv_h auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), flatten(stride(gbasis)), [&](auto s, auto d, auto e) @@ -561,7 +618,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM return cute::tuple<>{}; // If size-1 or stride-0, then don't append } else { using E = decltype(e); - auto has_e = any_of(stride(tile_gbasis), [] (auto tb) { return tb == E{}; }); + auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; }); if constexpr (decltype(has_e)::value) { return cute::tuple<>{}; // If d was found, then don't append } else { @@ -569,13 +626,10 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM } } }); - auto tile_gbasis_remaining_rank = rank(tile_gbasis_remaining_stride); - - // "Coalesce" the tile basis into a compatible shape with the tma - auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); // Append the remaining basis modes that contribute to the TMA with size-1 - auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat(Int<1>{}))), + auto tile_gbasis_remaining_shape = repeat(Int<1>{}); + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )), tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); // Group the trailing modes to make this max rank-5 -- TMA rank limitation @@ -583,15 +637,98 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM auto tma_gbasis = group(tma_gbasis_full); #if 0 - print("smem_inv_h : "); print(smem_inv_h); print("\n"); - print("gtensor : "); print(gtensor); print("\n"); print("tile_gstride : "); print(tile_gstride); print("\n"); print("tma_gstride : "); print(tma_gstride); print("\n"); print("gbasis : "); print(gbasis); print("\n"); - print("tile_gbasis : "); print(tile_gbasis); print("\n"); + print("tile_gbasis : "); print(tma_gbasis_tile); print("\n"); print("tma_gbasis : "); print(tma_gbasis); print("\n"); #endif + return tma_gbasis; +} + +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Tensor const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType + TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s) + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + static_assert(is_tuple::value); + static_assert(is_same::value || is_same::value); + + using TmaInternalType = typename GEngine::value_type; + constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value; + static_assert(TmaRank >= tma_rank); + + auto gmem_shape = shape(gtensor); + auto gmem_stride = stride(gtensor); + // Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + constexpr int tma_i_rank = decltype(rank(tma_gbasis_stride))::value; + if constexpr (tma_i_rank == 1) { + // Trivial contribution of this gmem mode to this tma mode + auto ej = unwrap(get(tma_gbasis_stride)); + gmem_prob_shape[i] = basis_get(ej, gmem_shape); + gmem_prob_stride[i] = basis_get(ej, gmem_stride) * sizeof_bits_v / 8; + } else { + // Apply a recurrence to each gmem mode that contributes to this tma mode + for_each(get(tma_gbasis_stride), [&](auto ej) { + // Problem shape + uint64_t shape_j = basis_get(ej, gmem_shape); + // Problem stride (in bytes) + uint64_t stride_j = basis_get(ej, gmem_stride) * sizeof_bits_v / 8; + + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); + } + }); +} + +// Overload for an existing Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Copy_Traits const& tma_traits, + Tensor const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}), + gmem_prob_shape, gmem_prob_stride); +} + +// Use a sidx2gmode to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. +// +template +CUTE_HOST_RTC +auto +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& tma_gbasis, // TMA mode -> GMEM mode mapping + Swizzle const& swizzle, // Swizzle fn on smem_idx + uint32_t num_multicast) // The number of CTAs in multicasting +{ // // TMA desc creation // @@ -602,31 +739,16 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // TMA gmem desc info // + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); auto gmem_layout = gtensor_T.layout(); cute::array gmem_prob_shape = {1,1,1,1,1}; cute::array gmem_prob_stride = {0,0,0,0,0}; - // Use the indirections in tma_gbasis in the values of flat_glayout to construct the gmem shapes/strides - for_each(make_seq{}, [&](auto i) { - for_each(stride(tma_gbasis), [&](auto ej) { - // Problem stride - uint64_t stride_j = ceil_div(basis_get(ej, stride(gmem_layout)) * sizeof_bits_v, 8); - uint64_t old_stride = gmem_prob_stride[i]; - gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); - // Problem shape - uint64_t shape_j = basis_get(ej, shape(gmem_layout)); - if (gmem_prob_stride[i] != 0) { - // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 - gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) - + (shape_j-1) * (stride_j / gmem_prob_stride[i]) - + 1; - } else { - gmem_prob_shape[i] = shape_j; - } - }); - }); + fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride); assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned @@ -663,6 +785,13 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM for_each(make_seq{}, [&](auto i) { smem_box_shape[i] *= size(tma_gbasis); }); + // Finally, truncate the tma box by the num_multicast + for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) { + assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0); + uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]); + smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast); + multicast = new_mult; + } assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 @@ -740,26 +869,27 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM auto recast_ratio = cute::ratio(Int::value>{}, Int::value>{}); + auto gbasis = make_basis_like(shape(gtensor)); + // Finally, get the inverse permutation of the E bases for the mocked gmem stride - // NOTE This is essentially ArithmeticTuple inverse... - auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) { + auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) { auto si = basis_get(ei, shape(gmem_layout)); auto di = basis_get(ei, stride(gmem_layout)); if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA } else { - auto tma_gbasis_stride = stride(tma_gbasis); + auto tma_gmem_basis_stride = stride(tma_gbasis); // Find j such that E is in stride(tma_gbasis) using EI = decltype(ei); - [[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); - if constexpr (decltype(j == rank(tma_gbasis_stride))::value) { + [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) { return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA } else if constexpr (decltype(j == Int<0>{})::value) { auto scale = recast_ratio * basis_get(ei, stride(gtensor)); return E{} * scale; // Return TMA Coord basis -- with a recast scale factor } else - if constexpr (decltype(rank(tma_gbasis_stride) == Int<1>{})::value) { + if constexpr (decltype(rank(tma_gmem_basis_stride) == Int<1>{})::value) { return E{}; // Return TMA Coord basis -- known scale of Int<1>{} } else { int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], 16)), 8); @@ -768,14 +898,64 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM } }); - #if 0 - print("tma_gbasis : "); print(gmem_stride_bases); print("\n"); - #endif +#if 0 + print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n"); +#endif - using AuxParams = AuxTmaParams; - return cute::make_tuple(tma_desc, AuxParams{gmem_stride_bases}); + return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride}); +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_atom(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // + // TMA truncated layout + // + + auto smem_swizzle = get_swizzle_portion(slayout); + auto smem_layout = get_nonswizzle_portion(slayout); + + auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); + + // + // Construct the TMA Desc and the strides of the TMA Tensor + // + + auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, + tma_gbasis, + smem_swizzle, + num_multicast); + + // + // Construct the Copy_Traits + // + + constexpr int num_bits_per_tma = decltype(size(tma_gbasis))::value * sizeof_bits_v; + using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; + +#if 0 + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); +#endif + + // Return the Copy_Atom + return Atom{tma_traits}; } // The "logical TMA tid" is a map from the CTA rank to its logical id @@ -790,122 +970,46 @@ template CUTE_HOST_RTC auto -make_tma_copy_tiled(CopyOp, +make_tma_copy_tiled(CopyOp const& copy_op, Tensor const& gtensor, // Full GMEM Tensor SLayout const& slayout, // CTA Tile of SMEM Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid Layout const& cta_v_map) // V: CTA val idx -> gmem mode { - // - // TMA parameter checking - // - - CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), - "TMA requires CTA_Tile and SLayout top-level shape equivalence."); - CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, - "Number of active CTAs in TMA must divide domain size of slayout."); - - // - // TMA slayout manipulation - // - - // Invert the smem to get the largest contiguous vector in the smem layout - // smem idx -> smem coord - auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); - - // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode - // smem idx -> gmem mode - auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); - -#if 0 - print("g_tensor : "); print(gtensor); print("\n"); - print("s_layout : "); print(slayout); print("\n"); - print("cta_t_map : "); print(cta_t_map); print("\n"); - print("cta_v_map : "); print(cta_v_map); print("\n"); - print("inv_s_layout : "); print(inv_smem_layout); print("\n"); - print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); -#endif - - // - // TMA gtensor manipulation - // - - // Generate a TupleBasis for the gtensor - // gmem coord -> gmem coord - auto glayout_basis = make_identity_layout(shape(gtensor)); - - // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc - // smem idx -> gmem coord - auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); - - // Truncate any incompatibilities -- no starting in the middle of gmodes - auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { - [[maybe_unused]] auto v = basis_value(e); - return not is_constant<1,decltype(v)>{}; - }); - static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); - - // Keep only the static-1 basis modes into gmem - auto tma_layout_trunc = take<0,smem_rank>(tma_layout_full); - - // Keep only the portion each multicast CTA will be responsible for - auto tma_layout_v = composition(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); - -#if 0 - print("glayout_basis : "); print(glayout_basis); print("\n"); - print("tma_layout_full : "); print(tma_layout_full); print("\n"); - - print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); - print("tma_layout_v : "); print(tma_layout_v); print("\n"); -#endif - - // - // Construct the TMA Desc and the strides of the TMA Tensor - // - - auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, - tma_layout_v, - get_swizzle_portion(slayout)); - - // - // Construct the Copy_Traits - // - - using T = typename GEngine::value_type; - constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof_bits_v; - using Traits = Copy_Traits, decltype(aux_params)>; - using Atom = Copy_Atom; - - Traits tma_traits{tma_desc, aux_params}; - -#if 0 - print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); - print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); -#endif + Copy_Atom atom = make_tma_copy_atom(copy_op, gtensor, slayout, + cosize(cta_t_map), cta_v_map); // // Construct the TiledCopy // - auto cta_tiler = product_each(shape(cta_v_map)); + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / Int>{}; + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); // CTA V -> smem_coord - auto layout_v = composition(inv_smem_layout, size(tma_layout_trunc)); + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); // CTA T -> smem idx - auto layout_t = make_layout(cosize(cta_t_map), shape_div(size(tma_layout_trunc), cosize(cta_t_map))); + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); // CTA TID -> smem coord auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); // Combine with the T mapping - auto layout_TV = make_layout(layout_T, layout_V); + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); #if 0 print("cta_tiler : "); print(cta_tiler); print("\n"); - print("layout_VT : "); print(layout_VT); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); print("layout_TV : "); print(layout_TV); print("\n"); #endif - return TiledCopy{tma_traits}; + return TiledCopy{atom}; } } // end namespace detail @@ -982,7 +1086,7 @@ make_tma_copy_tiled(CopyOp, copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params */ -template (copy_op, - gtensor, - slayout, - cta_t_tile, - cta_v_tile); + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); } // Explicit defaulting -template -CUTE_HOST_RTC -auto -make_tma_copy(CopyOp const& copy_op, - Tensor const& gtensor, - SLayout const& slayout, - CTA_Tile const& cta_tile, - Cluster_Size const& cluster_size) -{ - using TmaInternalType = typename GEngine::value_type; - return make_tma_copy(copy_op, - gtensor, - slayout, - cta_tile, - cluster_size); -} - template @@ -1039,6 +1122,7 @@ make_tma_copy(CopyOp const& copy_op, return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); } +// Explicit defaulting template +CUTE_HOST_RTC +auto +make_tma_atom(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, + gtensor, slayout, + size(cluster_size), cta_v_tile); +} + +// The "VectorCopy Partitioner" for TMA +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + CtaCoord const& cta_coord, + Layout const& cta_layout, // T: CTA coord -> logical multicast id + Tensor const& stensor, // SMEM Tensor (TMATile, Iter) + Tensor const& gtensor) // GMEM Tensor (TMATile, Iter) +{ + // Invert the smem to get the largest contiguous vector in the smem layout + Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); + // Scale that up to cover all of the smem_coords + Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor)); + + // Factor out the single-instrucion portion + Layout tma_layout_v = make_layout(Int::NumValSrc>{}); + Layout layout_V = logical_divide(layout_v, tma_layout_v); + + // Transform tile mode and coalesce + Tensor gtensor_v = coalesce(gtensor.compose(layout_V, _), Shape,_1>{}); // ((TMA,TMA_Iter),Iter) + Tensor stensor_v = coalesce(stensor.compose(layout_V, _), Shape,_1>{}); // ((TMA,TMA_Iter),Iter) + +#if 0 + if (thread0()) { + print("layout_V : "); print(layout_V); print("\n"); + print("gtensor_v : "); print(gtensor_v); print("\n"); + print("stensor_v : "); print(stensor_v); print("\n"); + } +#endif + + // Restride the cta-into-tma-instr layout + Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout); + Layout tma_layout_tv = make_layout(tma_layout_t, tma_layout_v); + + // Transform TMA mode + Tensor gtensor_tv = gtensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter) + Tensor stensor_tv = stensor_v.compose(make_tile(tma_layout_tv, _), _); // (((Thr,Frg),TMA_Iter),Iter) + +#if 0 + if (thread0()) { + print("tma_layout_tv : "); print(tma_layout_tv); print("\n"); + print("gtensor_tv : "); print(gtensor_tv); print("\n"); + print("stensor_tv : "); print(stensor_tv); print("\n"); + } +#endif + + // Slice and group Frg,TMA_Iter and return + auto c = make_coord(make_coord(make_coord(cta_coord, _), _), _); + return cute::make_tuple(group_modes<0,2>(gtensor_tv(c)), group_modes<0,2>(stensor_tv(c))); +} + } // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 80468914..27456f38 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -31,12 +31,12 @@ #pragma once #include -#include -#include +#include #include +#include #include namespace cute { @@ -196,39 +196,37 @@ struct MMA_Atom> template struct ThrMMA; +// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA +// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed. +// @tparam PermuationsMNK Permutations to apply to each MNK-mode before tiling for the Atom. template >, - class ValLayoutMNK = Layout>, - class PermutationsMNK = Tile> + class AtomLayoutMNK, + class PermutationMNK = Tile> struct TiledMMA : MMA_Atom { - static_assert(rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); - static_assert(rank_v == 3, "TiledMMA requires rank-3 ValLayoutMNK"); - static_assert(rank_v == 3, "TiledMMA requires rank-3 PermutationsMNK"); - + using Atom = MMA_Atom; using AtomShape_MNK = typename MMA_Atom::Shape_MNK; - + using AtomThrID = typename MMA_Atom::ThrID; using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; - // ThrV -> thread_idx - using AtomThrID = typename MMA_Atom::ThrID; + static_assert( rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert( rank_v == 3, "TiledMMA requires rank-3 PermutationMNK"); + static_assert( is_tile::value, "TiledMMA requires independent permutations of MNK."); + static_assert(is_static::value, "TiledMMA requires static permutations of MNK."); - // (M,N,K) - using TiledShape_MNK = decltype(make_shape(size<0>(AtomShape_MNK{})*size<0>(AtomLayoutMNK{})*size<0>(ValLayoutMNK{}), - size<1>(AtomShape_MNK{})*size<1>(AtomLayoutMNK{})*size<1>(ValLayoutMNK{}), - size<2>(AtomShape_MNK{})*size<2>(AtomLayoutMNK{})*size<2>(ValLayoutMNK{}))); - - // thrid = (ThrV,ThrM,ThrN,ThrK) -> thr_idx using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + ThrLayoutVMNK thr_layout_vmnk_; - // thr_idx -> (ThrV,ThrM,ThrN,ThrK) - using TidLayout = decltype(right_inverse(ThrLayoutVMNK{})); + CUTE_HOST_DEVICE constexpr + TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) + : MMA_Atom(mma_atom), + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} CUTE_HOST_DEVICE constexpr auto get_thr_layout_vmnk() const { - return ThrLayoutVMNK{}; + return thr_layout_vmnk_; } // Tile a tensor or a layout from shape @@ -243,17 +241,17 @@ struct TiledMMA : MMA_Atom // RestM: The values tiled in M. // RestN: The values tiled in N. template - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - thrfrg_C(CTensor&& ctensor) + thrfrg_C(CTensor&& ctensor) const { CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); - CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<0>(ctensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(ctensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), - left_inverse(get<1>(PermutationsMNK{}))); + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<1>(PermutationMNK{})); auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) // Tile the tensor for the Atom @@ -266,25 +264,13 @@ struct TiledMMA : MMA_Atom // Tile the tensor for the C-threads auto thr_tile = make_tile(_, - make_tile(make_layout(size<1>(ThrLayoutVMNK{})), - make_layout(size<2>(ThrLayoutVMNK{})))); + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<2>(thr_layout_vmnk_)))); auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) return thr_tensor; } - // Tile from (M,N,...) - // to (thr_idx,(FrgV,(RestM,RestN,...))) - template - CUTE_HOST_DEVICE constexpr static - auto - tidfrg_C(CTensor&& ctensor) - { - // Don't need a ctile composition because ThrK is last mode in TidLayout - - return thrfrg_C(ctensor).compose(TidLayout{}, _); - } - // Tile a tensor or a layout from shape // (M,K,...) // to shape @@ -297,17 +283,17 @@ struct TiledMMA : MMA_Atom // RestM: The values tiled in M. // RestK: The values tiled in K. template - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - thrfrg_A(ATensor&& atensor) + thrfrg_A(ATensor&& atensor) const { CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); //CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - //UTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), - left_inverse(get<2>(PermutationsMNK{}))); + auto t_tile = make_tile(get<0>(PermutationMNK{}), + get<2>(PermutationMNK{})); auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) // Tile the tensor for the Atom @@ -320,29 +306,13 @@ struct TiledMMA : MMA_Atom // Tile the tensor for the Thread auto thr_tile = make_tile(_, - make_tile(make_layout(size<1>(ThrLayoutVMNK{})), - make_layout(size<3>(ThrLayoutVMNK{})))); + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) return thr_tensor; } - // Tile from (M,K,...) - // to (thr_idx,(FrgV,(RestM,RestK,...))) - template - CUTE_HOST_DEVICE constexpr static - auto - tidfrg_A(ATensor&& atensor) - { - auto atile = make_tile(_, - make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), - make_stride( Int<1>{} , Int<0>{} )), - _)); - // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) - - return thrfrg_A(atensor).compose(atile, _).compose(TidLayout{}, _); - } - // Tile a tensor or a layout from shape // (N,K,...) // to shape @@ -355,17 +325,17 @@ struct TiledMMA : MMA_Atom // RestN: The values tiled in N. // RestK: The values tiled in K. template - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - thrfrg_B(BTensor&& btensor) + thrfrg_B(BTensor&& btensor) const { CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); //CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); //CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); // Reorder the tensor for the TiledAtom - auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})), - left_inverse(get<2>(PermutationsMNK{}))); + auto t_tile = make_tile(get<1>(PermutationMNK{}), + get<2>(PermutationMNK{})); auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) // Tile the tensor for the Atom @@ -378,44 +348,28 @@ struct TiledMMA : MMA_Atom // Tile the tensor for the Thread auto thr_tile = make_tile(_, - make_tile(make_layout(size<2>(ThrLayoutVMNK{})), - make_layout(size<3>(ThrLayoutVMNK{})))); + make_tile(make_layout(size<2>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) return thr_tensor; } - // Tile from (N,K,...) - // to (thr_idx,(FrgV,(RestN,RestK,...))) - template - CUTE_HOST_DEVICE constexpr static + template ::value)> + CUTE_HOST_DEVICE constexpr auto - tidfrg_B(BTensor&& btensor) + get_slice(ThrIdx const& thr_idx) const { - auto btile = make_tile(_, - make_tile(make_layout(make_shape (size<1>(ThrLayoutVMNK{}), size<2>(ThrLayoutVMNK{})), - make_stride( Int<0>{} , Int<1>{} )), - _)); - // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) - - return thrfrg_B(btensor).compose(btile, _).compose(TidLayout{}, _); + auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx); + return ThrMMA{*this, thr_vmnk}; } template ::value)> - CUTE_HOST_DEVICE static constexpr + CUTE_HOST_DEVICE constexpr auto - get_slice(ThrIdx const& thr_idx) - { - auto thr_vmnk = ThrLayoutVMNK{}.get_flat_coord(thr_idx); - return ThrMMA(thr_vmnk); - } - - template ::value)> - CUTE_HOST_DEVICE static constexpr - auto - get_thread_slice(ThrIdx const& thr_idx) + get_thread_slice(ThrIdx const& thr_idx) const { return get_slice(thr_idx); } @@ -424,104 +378,144 @@ struct TiledMMA : MMA_Atom // Utility for printing and visualization // - CUTE_HOST_DEVICE constexpr static + // The size of the MNK-mode + template + CUTE_HOST_DEVICE constexpr auto - get_layoutC_MN() + tile_size_mnk() const { + static_assert(0 <= I && I < 3); + auto core_size = size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()); + [[maybe_unused]] auto perm_size = size(PermutationMNK{}); + if constexpr (is_underscore::value) { + return core_size; + } else { + return cute::max(core_size, perm_size); + } + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_MN() const { // (M,N) -> (M,N) - auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); // (cthrid,val) -> (M,N) auto layoutC_TV = thrfrg_C(ref_C); // (M,N) -> (cthrid,frg) auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); // cthrid = (v,m,n) -> thr_idx - auto thrID_C = ThrLayoutVMNK{}(_,_,_,Int<0>{}); + auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{}); return cute::make_tuple(layoutC_MN, thrID_C); } - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - get_layoutC_TV() + get_layoutC_TV() const { // (M,N) -> (M,N) - auto ref_C = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<1>(TiledShape_MNK{}))); + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); - return tidfrg_C(ref_C); + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,N) + return layoutC_TV.compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - get_layoutA_MK() + get_layoutA_MK() const { // (M,K) -> (M,K) - auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); // (athrid,val) -> (M,K) auto layoutA_TV = thrfrg_A(ref_A); // (M,K) -> (athrid,frg) auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); // athrid = (v,m,k) -> thr_idx - auto thrID_A = ThrLayoutVMNK{}(_,_,Int<0>{},_); + auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_); return cute::make_tuple(layoutA_MK, thrID_A); } - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - get_layoutA_TV() + get_layoutA_TV() const { // (M,K) -> (M,K) - auto ref_A = make_layout(make_shape(size<0>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); - return tidfrg_A(ref_A); + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); } - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - get_layoutB_NK() + get_layoutB_NK() const { // (N,K) -> (N,K) - auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); // (bthrid,val) -> (N,K) auto layoutB_TV = thrfrg_B(ref_B); // (N,K) -> (bthrid,frg) auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); // bthrid = (v,n,k) -> thr_idx - auto thrID_B = ThrLayoutVMNK{}(_,Int<0>{},_,_); + auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_); return cute::make_tuple(layoutB_NK, thrID_B); } - CUTE_HOST_DEVICE constexpr static + CUTE_HOST_DEVICE constexpr auto - get_layoutB_TV() + get_layoutB_TV() const { // (N,K) -> (N,K) - auto ref_B = make_layout(make_shape(size<1>(TiledShape_MNK{}), size<2>(TiledShape_MNK{}))); + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); - return tidfrg_B(ref_B); + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (N,K) + return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); } }; template struct ThrMMA : TiledMMA { - // Use ThrVMNK and thrfrg rather than thr_idx and tidfrg - // to support swizzled threads partitioning dynamic layouts ThrVMNK thr_vmnk_; - CUTE_HOST_DEVICE constexpr - ThrMMA(ThrVMNK const& thr_vmnk) : thr_vmnk_(thr_vmnk) {} - template CUTE_HOST_DEVICE constexpr auto partition_C(CTensor&& ctensor) const { - auto thr_tensor = make_tensor(std::forward(ctensor).data(), TiledMMA::thrfrg_C(ctensor.layout())); + auto thr_tensor = make_tensor(std::forward(ctensor).data(), this->thrfrg_C(ctensor.layout())); auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); @@ -532,7 +526,7 @@ struct ThrMMA : TiledMMA auto partition_A(ATensor&& atensor) const { - auto thr_tensor = make_tensor(std::forward(atensor).data(), TiledMMA::thrfrg_A(atensor.layout())); + auto thr_tensor = make_tensor(std::forward(atensor).data(), this->thrfrg_A(atensor.layout())); auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); @@ -543,7 +537,7 @@ struct ThrMMA : TiledMMA auto partition_B(BTensor&& btensor) const { - auto thr_tensor = make_tensor(std::forward(btensor).data(), TiledMMA::thrfrg_B(btensor.layout())); + auto thr_tensor = make_tensor(std::forward(btensor).data(), this->thrfrg_B(btensor.layout())); auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); @@ -580,38 +574,32 @@ struct ThrMMA : TiledMMA template >, - class MMAValLayout = Layout>, class Permutations = Tile> CUTE_HOST_DEVICE constexpr auto -make_tiled_mma(MMA_Atom const&, +make_tiled_mma(MMA_Atom const& mma_atom, MMAThrLayout const& thr_layout = {}, - MMAValLayout const& val_layout = {}, Permutations const& permutations = {}) { auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); - auto val_layout_mnk = append<3>(val_layout, Layout<_1,_0>{}); auto permutation_mnk = append<3>(permutations, _); return TiledMMA, decltype(thr_layout_mnk), - decltype(val_layout_mnk), - decltype(permutation_mnk)>{}; + decltype(permutation_mnk)>{mma_atom, thr_layout_mnk}; } template >, - class MMAValLayout = Layout>, class Permutations = Tile> CUTE_HOST_DEVICE constexpr auto make_tiled_mma(MMA_Op const&, MMAThrLayout const& thr_layout = {}, - MMAValLayout const& val_layout = {}, Permutations const& permutations = {}) { // Attempt to wrap in an MMA_Atom<> and forward - return make_tiled_mma(MMA_Atom{}, thr_layout, val_layout, permutations); + return make_tiled_mma(MMA_Atom{}, thr_layout, permutations); } // @@ -680,28 +668,38 @@ partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) // Size // -template +template CUTE_HOST_DEVICE constexpr auto tile_size(TiledMMA const& mma) { - return size(typename TiledMMA::TiledShape_MNK{}); + return mma.template tile_size_mnk(); } -template +template CUTE_HOST_DEVICE constexpr auto tile_shape(TiledMMA const& mma) { - return shape(typename TiledMMA::TiledShape_MNK{}); + return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma)); } +// Deprecate? template CUTE_HOST_DEVICE constexpr auto size(TiledMMA const& mma) { - return size(typename TiledMMA::ThrLayoutVMNK{}); + return size(mma.get_thr_layout_vmnk()); +} + +// Alias +template +CUTE_HOST_DEVICE constexpr +auto +thr_size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); } // @@ -715,33 +713,31 @@ print(MMA_Atom> const&) { using Atom = MMA_Atom>; print("MMA_Atom\n"); - print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); - print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); - print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); - print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); } -template +template CUTE_HOST_DEVICE void -print(TiledMMA const& mma) +print(TiledMMA const& mma) { - using MMA = TiledMMA; print("TiledMMA\n"); - print(" TiledThr: "); print(TiledThr{}); print("\n"); - print(" TiledVal: "); print(TiledVal{}); print("\n"); - print(" TiledPerm: "); print(TiledPerm{}); print("\n"); - print(" TiledShape_MNK: "); print(typename MMA::TiledShape_MNK{}); print("\n"); - print(" ThrLayoutVMNK: "); print(typename MMA::ThrLayoutVMNK{}); print("\n"); + print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n"); + print(" PermutationMNK: "); print(TiledPerm{}); print("\n"); print(static_cast(mma)); } template CUTE_HOST_DEVICE void -print(ThrMMA const&) +print(ThrMMA const& thr_mma) { - print(TiledMMA{}); + print("ThrMMA\n"); + print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n"); + print(static_cast(thr_mma)); } template @@ -766,18 +762,6 @@ print_latex(TiledMMA const& mma) layoutB_NK, thrID_B); } -// EXPERIMENTAL -- Doesn't work with Swizzled Thr TileMMAs... -template -CUTE_HOST_DEVICE -auto -print_latex_2(TiledMMA const& mma) -{ - print_latex_mma(typename TiledMMA::TiledShape_MNK{}, - mma.get_layoutC_TV(), - mma.get_layoutA_TV(), - mma.get_layoutB_TV()); -} - // MNK MMA Layout to console printer -- 8-value color coded by thread template (tid,vid) and printf(latex_footer); } -// ThrVal MMA Layout to Latex TIKZ -- 8-value color coded by thread -template -CUTE_HOST_DEVICE -void -print_latex_mma(Shape_MNK const& shape_mnk, - LayoutC const& C, // (thr_idx,vid) -> (m,n) - LayoutA const& A, // (thr_idx,vid) -> (m,k) - LayoutB const& B) // (thr_idx,vid) -> (n,k) -{ - CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); - CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); - - char const* latex_header = - "\\documentclass{standalone}\n" - "\\usepackage{tikz}\n" - "\\usetikzlibrary{external}\n" - "\\tikzexternalize\n" - "\\begin{document}\n" - "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n"; - char const* latex_footer = - "\\end{tikzpicture}\n" - "\\end{document}\n"; - - char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", - "{rgb,255:red,175;green,255;blue,175}", - "{rgb,255:red,255;green,255;blue,175}", - "{rgb,255:red,255;green,175;blue,175}", - "{rgb,255:red,210;green,210;blue,255}", - "{rgb,255:red,210;green,255;blue,210}", - "{rgb,255:red,255;green,255;blue,210}", - "{rgb,255:red,255;green,210;blue,210}"}; - - // Header - printf("%% Shape_MNK: "); print(shape_mnk); printf("\n"); - printf("%% LayoutC : "); print(C); printf("\n"); - printf("%% LayoutA : "); print(A); printf("\n"); - printf("%% LayoutB : "); print(B); printf("\n\n"); - - printf(latex_header); - - auto M = size<0>(shape_mnk); - auto N = size<1>(shape_mnk); - auto K = size<2>(shape_mnk); - - // C starting at 0,0 - bool c_filled[M][N] = {}; - for (int t = 0; t < size<0>(C); ++t) { - for (int v = 0; v < size<1>(C); ++v) { - int m = C(t,v) % M; - int n = C(t,v) / M; - - if (not c_filled[m][n]) { - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[t % 8], - m, n, - t, v); - c_filled[m][n] = true; - } - } - } - - // A starting at 0,-size<1>(A)-1 - bool a_filled[M][K] = {}; - for (int t = 0; t < size<0>(A); ++t) { - for (int v = 0; v < size<1>(A); ++v) { - int m = A(t,v) % M; - int k = A(t,v) / M; - - if (not a_filled[m][k]) { - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[t % 8], - m, k - 1 - K, - t, v); - a_filled[m][k] = true; - } - } - } - - // B starting at -size<1>(B)-1,0 - bool b_filled[N][K] = {}; - for (int t = 0; t < size<0>(B); ++t) { - for (int v = 0; v < size<1>(B); ++v) { - int n = B(t,v) % N; - int k = B(t,v) / N; - - if (not b_filled[n][k]) { - printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", - color_map[t % 8], - k - 1 - K, n, - t, v); - b_filled[n][k] = true; - } - } - } - - // A labels - for (int m = 0, k = -1; m < M; ++m) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, m); - } - for (int k = 0, m = -1; k < K; ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k - 1 - K, k); - } - // B labels - for (int n = 0, k = -1; n < N; ++n) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, n); - } - for (int k = 0, n = -1; k < K; ++k) { - printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k - 1 - K, n, k); - } - - // Footer - printf(latex_footer); -} - } // namespace cute //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/config.hpp b/include/cute/config.hpp index e4bda683..0cea8b15 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -47,7 +47,7 @@ #endif #if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ - (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) # define CUTE_UNROLL #pragma unroll # define CUTE_NO_UNROLL #pragma unroll 1 #elif defined(__CUDACC_RTC__) || defined(__clang__) @@ -120,6 +120,8 @@ #include #endif +#define CUTE_STATIC_V(x) decltype(x)::value + #define CUTE_STATIC_ASSERT static_assert #define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 88e7abf6..135c2980 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -91,7 +91,7 @@ private: // Flag for fast branching on straddled elements static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); - friend class subbyte_iterator; + friend struct subbyte_iterator; // Pointer to storage element storage_type* ptr_ = nullptr; @@ -208,7 +208,7 @@ struct subbyte_iterator private: - template friend class swizzle_ptr; + template friend struct swizzle_ptr; // Pointer to storage element storage_type* ptr_ = nullptr; diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index d7a59b88..2522d6b6 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -327,7 +327,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b) { if constexpr (is_tuple::value && is_tuple::value) { static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); - constexpr int R = tuple_size::value; // Missing ranks in TupleB are implictly 1 + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); } else { return (a + b - Int<1>{}) / b; @@ -336,6 +336,28 @@ ceil_div(IntTupleA const& a, IntTupleB const& b) CUTE_GCC_UNREACHABLE; } +// +// round_up +// Round @a a up to the nearest multiple of @a b. +// For negative numbers, rounds away from zero. +// + +template +CUTE_HOST_DEVICE constexpr +auto +round_up(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); + } else { + return ((a + b - Int<1>{}) / b) * b; + } + + CUTE_GCC_UNREACHABLE; +} + /** Division for Shapes * Case Tuple Tuple: * Perform shape_div element-wise @@ -429,6 +451,7 @@ template using is_congruent = decltype(congruent(declval(), declval())); /** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + * weakly_congruent is a partial order on A and B: A <= B */ template CUTE_HOST_DEVICE constexpr @@ -458,7 +481,7 @@ using is_weakly_congruent = decltype(weakly_congruent(declval(), declval() /** Test if Shape B is compatible with Shape A: * Any coordinate into A can also be used as a coordinate into B - * A <= B is a partially ordered set of factored shapes + * compatible is a partial order on A and B: A <= B */ template CUTE_HOST_DEVICE constexpr @@ -487,7 +510,8 @@ template using is_compatible = decltype(compatible(declval(), declval())); /** Test if Shape B is weakly compatible with Shape A: - * Shape B divides Shape A at some level of refinement + * Shape B is a multiple of a shape that is compatible with Shape A + * weakly_compatible is a partial order on A and B: A <= B */ template CUTE_HOST_DEVICE constexpr @@ -502,7 +526,7 @@ weakly_compatible(IntTupleA const& a, IntTupleB const& b) [](auto const&... z) { return (true_type{} && ... && z); }); } } else if constexpr (is_integral::value) { - return a % size(b) == Int<0>{}; + return size(b) % a == Int<0>{}; } else if constexpr (is_integral::value) { return false_type{}; } else { diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index fb30e4f3..2d01fd51 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -981,7 +981,6 @@ auto composition(Layout const& lhs, Layout const& rhs) { - //return detail::composition_impl(flatten(lhs), rhs.shape(), rhs.stride()); return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); } @@ -997,8 +996,8 @@ composition(Layout const& lhs, return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); } else if constexpr (is_underscore::value) { return lhs; - } else { - return composition(lhs, make_layout(rhs)); + } else if constexpr (is_integral::value) { + return detail::composition_impl(lhs, rhs, Int<1>{}); } CUTE_GCC_UNREACHABLE; @@ -1097,15 +1096,18 @@ inverse_seq(Shape const& shape, Stride const& stride, seq) auto next_I = cute::find_if(stride, [](auto a) { return is_constant{}; }); if constexpr (next_I == decltype(rank(stride))::value) { + // If not found, return current seq return seq{}; } else { // auto next_stride = get(shape) * get(stride); // NOTE: Needed for g++-7 using next_stride = decltype(get(shape) * get(stride)); - if constexpr (is_static::value) { + if constexpr (is_static::value && !is_constant::value) { + // If next_stride is static and unique, then continue return inverse_seq(shape, stride, seq{}); } else { + // Else return current seq + next_I return seq{}; } } @@ -1340,28 +1342,24 @@ template const& layout, - Layout const& tile) + Layout const& tiler) { - //CUTE_STATIC_ASSERT_V(size(layout) % size(tile) == Int<0>{}, - // "Tiling does not evenly divide the block"); - // NOTE: With tiles that have stride-0, this doesn't have to be true - - return composition(layout, make_layout(tile, complement(tile, size(layout)))); + return composition(layout, make_layout(tiler, complement(tiler, size(layout)))); } -template +template CUTE_HOST_DEVICE constexpr auto logical_divide(Layout const& layout, - IntTuple const& tile) + Tiler const& tiler) { - if constexpr (is_tuple::value) { - static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tile."); - return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_divide(l,t); }); - } else if constexpr (is_underscore::value) { + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { return layout; - } else if constexpr (is_integral::value) { - return logical_divide(layout, make_layout(tile)); + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tiler)); } CUTE_GCC_UNREACHABLE; @@ -1374,24 +1372,24 @@ logical_divide(Layout const& layout, // template + class Tiler> CUTE_HOST_DEVICE constexpr auto zipped_divide(Layout const& layout, - Tile const& tile) + Tiler const& tiler) { - return tile_unzip(logical_divide(layout, tile), tile); + return tile_unzip(logical_divide(layout, tiler), tiler); } // Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) template + class Tiler> CUTE_HOST_DEVICE constexpr auto tiled_divide(Layout const& layout, - Tile const& tile) + Tiler const& tiler) { - auto div = zipped_divide(layout, tile); + auto div = zipped_divide(layout, tiler); auto R = rank<1>(div); return div(_, repeat(_)); @@ -1399,13 +1397,13 @@ tiled_divide(Layout const& layout, // Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) template + class Tiler> CUTE_HOST_DEVICE constexpr auto flat_divide(Layout const& layout, - Tile const& tile) + Tiler const& tiler) { - auto div = zipped_divide(layout, tile); + auto div = zipped_divide(layout, tiler); auto R0 = rank<0>(div); auto R1 = rank<1>(div); @@ -1421,24 +1419,24 @@ template const& layout, - Layout const& tile) + Layout const& tiler) { - return make_layout(layout, composition(complement(layout, size(layout)*cosize(tile)), tile)); + return make_layout(layout, composition(complement(layout, size(layout)*cosize(tiler)), tiler)); } -template +template CUTE_HOST_DEVICE constexpr auto logical_product(Layout const& layout, - IntTuple const& tile) + Tiler const& tiler) { - if constexpr (is_tuple::value) { - static_assert(tuple_size::value <= Layout::rank); - return transform_layout(layout, tile, [](auto const& l, auto const& t) { return logical_product(l,t); }); - } else if constexpr (is_underscore::value) { + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { return layout; - } else if constexpr (is_integral::value) { - return logical_product(layout, make_layout(tile)); + } else if constexpr (is_integral::value) { + return logical_product(layout, make_layout(tiler)); } CUTE_GCC_UNREACHABLE; @@ -1451,45 +1449,43 @@ logical_product(Layout const& layout, // template + class Tiler> CUTE_HOST_DEVICE constexpr auto zipped_product(Layout const& layout, - Tile const& tile) + Tiler const& tiler) { - return tile_unzip(logical_product(layout, tile), tile); + return tile_unzip(logical_product(layout, tiler), tiler); } // Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) template + class Tiler> CUTE_HOST_DEVICE constexpr auto tiled_product(Layout const& layout, - Tile const& tile) + Tiler const& tiler) { - auto div = zipped_product(layout, tile); + auto div = zipped_product(layout, tiler); - auto R = rank(tile); + auto R = rank<1>(div); return div(_, repeat(_)); } -// Attempts to reproduce layout "block" over layout "layout" -// That is, think of every element of "layout" as a "block" +// Attempts to reproduce a layout over a tiler +// That is, think of every element of "tiler" as a "layout" // and return the layout of the resulting structure template CUTE_HOST_DEVICE constexpr auto -blocked_product(Layout const& block, - Layout const& layout) +blocked_product(Layout const& layout, + Layout const& tiler) { constexpr int R = cute::max(rank_v, rank_v); - auto padded_block = append(block); - auto padded_layout = append(layout); - - auto result = logical_product(padded_block, padded_layout); + auto result = logical_product(append(layout), append(tiler)); + return coalesce(zip(get<0>(result), get<1>(result)), repeat(Int<1>{})); } @@ -1497,14 +1493,12 @@ template CUTE_HOST_DEVICE constexpr auto -raked_product(Layout const& block, - Layout const& layout) +raked_product(Layout const& layout, + Layout const& tiler) { constexpr int R = cute::max(rank_v, rank_v); - auto padded_block = append(block); - auto padded_layout = append(layout); - auto result = logical_product(padded_block, padded_layout); + auto result = logical_product(append(layout), append(tiler)); return coalesce(zip(get<1>(result), get<0>(result)), repeat(Int<1>{})); } diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index b8765d4a..69e47182 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -473,6 +473,16 @@ zipped_divide(ComposedLayout const& a, return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); } +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); +} + template CUTE_HOST_DEVICE constexpr diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index ac6ff539..0b248fad 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -181,7 +181,7 @@ template CUTE_HOST_DEVICE constexpr auto make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { - return make_tuple_iter(cute::make_tuple(t0, t1, ts...)); + return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); } // diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index e2c1e0ae..bd548a46 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -148,7 +148,9 @@ using _96 = Int<96>; using _128 = Int<128>; using _192 = Int<192>; using _256 = Int<256>; +using _384 = Int<384>; using _512 = Int<512>; +using _768 = Int<768>; using _1024 = Int<1024>; using _2048 = Int<2048>; using _4096 = Int<4096>; diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index 82f4b972..f847594f 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -97,7 +97,7 @@ template ::value && is_std_integral::value)> CUTE_HOST_DEVICE constexpr -auto +cute::common_type_t gcd(T t, U u) { while (true) { if (t == 0) { return u; } @@ -112,7 +112,7 @@ template ::value && is_std_integral::value)> CUTE_HOST_DEVICE constexpr -auto +cute::common_type_t lcm(T const& t, U const& u) { return (t / gcd(t,u)) * u; } diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp index 75cdf8c2..ce951b7b 100644 --- a/include/cute/pointer_base.hpp +++ b/include/cute/pointer_base.hpp @@ -233,14 +233,14 @@ CUTE_HOST_DEVICE void print(T const* const ptr) template CUTE_HOST_DEVICE void print(counting_iterator ptr) { - printf("counting_iter_"); print(ptr.n_); + printf("counting_iter("); print(ptr.n_); printf(")"); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) { - return os << "counting_iter_" << ptr.n_; + return os << "counting_iter(" << ptr.n_ << ")"; } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index c25ce1d2..938842fc 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -990,13 +990,10 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) { print(tensor); print(":\n"); - auto format = get_format(tensor(0)); - using type = typename decltype(format)::type; - if constexpr (Layout::rank == 1) { for (int m = 0; m < size(tensor); ++m) { - printf(format.format, format.digits, type(tensor(m))); + pretty_print(tensor(m)); printf("\n"); } } else @@ -1004,7 +1001,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) { for (int m = 0; m < size<0>(tensor); ++m) { for (int n = 0; n < size<1>(tensor); ++n) { - printf(format.format, format.digits, type(tensor(m,n))); + pretty_print(tensor(m,n)); } printf("\n"); } @@ -1013,7 +1010,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) { print_tensor(tensor(_,_,0)); for (int k = 1; k < size<2>(tensor); ++k) { - for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("-"); } print("\n"); + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); print_tensor(tensor(_,_,k)); } } else @@ -1021,7 +1018,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) { print_tensor(tensor(_,_,_,0)); for (int p = 1; p < size<3>(tensor); ++p) { - for (int i = 0; i < format.digits*size<1>(tensor); ++i) { print("="); } print("\n"); + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); print_tensor(tensor(_,_,_,p)); } } diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index 31dad07f..cf75e3dd 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -57,61 +57,6 @@ num_digits(int x) 10))))))))); } -template -struct format_and_size { - using type = T; - char const* format; - int digits; -}; - -CUTE_HOST_DEVICE -format_and_size -get_format(bool) { - return {"%*d", 3}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(int32_t) { - return {"%*d", 5}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(uint32_t) { - return {"%*d", 5}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(int64_t) { - return {"%*d", 5}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(uint64_t) { - return {"%*d", 5}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(half_t) { - return {"%*.2f", 8}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(float) { - return {"%*.2e", 10}; -} - -CUTE_HOST_DEVICE -format_and_size -get_format(double) { - return {"%*.3e", 11}; -} - // // print dispatcher // @@ -195,4 +140,54 @@ print(char const* format) { printf("%s", format); } +// +// pretty printing +// + +template +CUTE_HOST_DEVICE void +pretty_print(T const& v) { + printf(" "); print(v); +} + +CUTE_HOST_DEVICE void +pretty_print(bool const& v) { + printf("%*d", 3, int(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(int32_t const& v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(uint32_t const& v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(int64_t const& v) { + printf("%*lld", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint64_t const& v) { + printf("%*llu", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(half_t const& v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float const& v) { + printf("%*.2e", 10, v); +} + +CUTE_HOST_DEVICE void +pretty_print(double const& v) { + printf("%*.3e", 11, v); +} + } // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index d3c75eb5..947aefde 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -122,6 +122,9 @@ using CUTE_STL_NAMESPACE::is_empty_v; using CUTE_STL_NAMESPACE::invoke_result_t; +using CUTE_STL_NAMESPACE::common_type; +using CUTE_STL_NAMESPACE::common_type_t; + // using CUTE_STL_NAMESPACE::declval; diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 3ef5e110..bc0dbbe3 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -47,6 +47,18 @@ namespace cutlass { namespace arch { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts +// This enum class specifies the NamedBarriers reserved by CUTLASS. +enum class ReservedNamedBarriers { + EpilogueBarrier = 0, + TransposeBarrier = 1, + TransformBarrier = 2, + StreamkBarrier0 = 3, + StreamkBarrier1 = 4 + , FirstUserBarrier = StreamkBarrier1 + 1 +}; + + class NamedBarrier { // Data Members: @@ -60,9 +72,19 @@ class NamedBarrier { public: + // Constructor for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) + : num_threads_(num_threads), id_(static_cast(reserved_named_barriers)) {} + + // Constructor for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount CUTLASS_DEVICE NamedBarrier(uint32_t num_threads, uint32_t id = 0) - : num_threads_(num_threads), id_(id) {} + : num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) { + CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16."); + } CUTLASS_DEVICE void arrive_and_wait() const { @@ -80,8 +102,52 @@ class NamedBarrier { } // Static variants + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount CUTLASS_DEVICE static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { + arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_and_wait_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, uint32_t barrier_id) { + arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void sync(uint32_t num_threads, uint32_t barrier_id) { + sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + sync_internal(num_threads, static_cast(reserved_named_barriers)); + } + + private: + CUTLASS_DEVICE + static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) @@ -90,7 +156,7 @@ class NamedBarrier { } CUTLASS_DEVICE - static void arrive(uint32_t num_threads, uint32_t barrier_id) { + static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) @@ -99,9 +165,16 @@ class NamedBarrier { } CUTLASS_DEVICE - static void sync(uint32_t num_threads, uint32_t barrier_id) { - NamedBarrier::arrive_and_wait(num_threads, barrier_id); + static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { + NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); } + + public: + // Currently we reserve 8 NamedBarriers for CUTLASS' own use cases, + // while leaving the renaming for general users. + static const uint32_t ReservedNamedBarrierCount = static_cast(ReservedNamedBarriers::FirstUserBarrier); + static const uint32_t HardwareMaxNumNamedBarriers = 16; + }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 7ec158b1..e0c02147 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -80,7 +80,7 @@ public: static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements - static size_t const kStorageElements = N / kElementsPerStoredItem; + static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; /// Number of logical elements static size_t const kElements = N; diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 04c63af2..56d66b83 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -68,7 +68,7 @@ template < struct NamedBarrierSync { CUTLASS_DEVICE static void sync() { - cutlass::arch::NamedBarrier::sync(ThreadCount, BarrierId); + cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast(BarrierId)); } }; @@ -227,9 +227,9 @@ template < uint32_t MaxNumNamedBarriers = 16 > struct NamedBarrierManager { - static constexpr uint32_t HardwareMaxNumNamedBarriers = 16; - static_assert(MaxNumNamedBarriers <= HardwareMaxNumNamedBarriers); - static_assert(MaxNumNamedBarriers + Offset <= HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); + + static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers); + static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); // Number of threads participating in the barrier static constexpr uint32_t ThreadCount = ThreadCount_; diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 9ddd96f6..57e44304 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -55,6 +55,7 @@ #include #endif +#include #include "cutlass/cutlass.h" namespace cutlass { @@ -83,6 +84,28 @@ struct alignas(2) bfloat16_t { return h; } +private: + struct from_32_bit_integer_t {}; + static constexpr from_32_bit_integer_t from_32_bit_integer{}; + + template + CUTLASS_HOST_DEVICE + explicit bfloat16_t(from_32_bit_integer_t, T x) { + static_assert(cutlass::platform::is_integral::value && sizeof(T) == 4, "Requires 32-bit integer"); + + float flt = static_cast(x); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(flt); + #else + std::memcpy(&bits, &flt, sizeof(bits)); + #endif + + storage = uint16_t(bits >> 16); + } + +public: /// Default constructor bfloat16_t() = default; @@ -129,18 +152,10 @@ struct alignas(2) bfloat16_t { /// Integer conversion - round toward nearest CUTLASS_HOST_DEVICE - explicit bfloat16_t(int x) { - float flt = static_cast(x); - uint32_t bits; + explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {} - #if defined(__CUDA_ARCH__) - bits = reinterpret_cast(flt); - #else - std::memcpy(&bits, &flt, sizeof(bits)); - #endif - - storage = uint16_t(bits >> 16); - } + CUTLASS_HOST_DEVICE + explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {} /// Converts to float CUTLASS_HOST_DEVICE diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 21923641..28611d51 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -35,7 +35,6 @@ #pragma once -#include #include #include "cutlass/cutlass.h" #include "cutlass/trace.h" diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 519d6760..c26bc73f 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -28,6 +28,16 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #include diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp new file mode 100644 index 00000000..c9960bc3 --- /dev/null +++ b/include/cutlass/cuda_host_adapter.hpp @@ -0,0 +1,147 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Interface betweeen a CUTLASS device-wide operator and CUDA. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" + +#include "cutlass/platform/platform.h" +#if ! defined(__CUDACC_RTC__) +#include +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Macro-level guard for CUDA Host Adapter +// +#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) +#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and +/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. +struct CudaHostAdapter { + + /// Limit the number of kernels + static constexpr int32_t kMaximumKernelCount = 4; + + /// Maximum cluster size + static constexpr int MaxClusterSize = 32; + + // + // Data members + // + + /// Handles + void *kernel_handles[kMaximumKernelCount]; + int32_t kernel_count = 0; + + CudaHostAdapter() = default; + + /// Dtor + virtual ~CudaHostAdapter() {} + + /// Copy Ctor deleted + CudaHostAdapter(const CudaHostAdapter&) = delete; + + /// Copy Assignment deleted + CudaHostAdapter& operator=(const CudaHostAdapter&) = delete; + + /// Move ctor deleted + CudaHostAdapter(CudaHostAdapter&&) = delete; + + /// Move assignment deleted + CudaHostAdapter& operator=(CudaHostAdapter&&) = delete; + + + /// Ctor + inline CudaHostAdapter( + void **kernel_handles_, + int32_t kernel_count_ + ): + kernel_count(kernel_count_) + { + CUTLASS_ASSERT(kernel_count >= 0); + for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = kernel_handles_[i]; + } + } + + /// Queries the occupancy of a kernel + virtual Status query_occupancy( + int32_t *device_sms, + int32_t *sm_occupancy, + int32_t kernel_index, + int32_t thread_count, + int32_t smem_size) = 0; + + /// Launches a kernel without using Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) = 0; + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) = 0; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp new file mode 100644 index 00000000..b0df53f7 --- /dev/null +++ b/include/cutlass/detail/collective.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/container/tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct deduce_mixed_width_dtype { +static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively."); + +private: + using underlying_tuple = cute::conditional_t::value, Tuple, cute::tuple>; + static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v - 1); + +public: + using type = cute::conditional_t<(I < cute::tuple_size_v), + cute::tuple_element_t, + void>; +}; + +template +using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index becc077c..74c72cdc 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -187,6 +187,7 @@ constexpr bool is_tma_copy_engine() { || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 54940f67..8559bbad 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -104,19 +104,22 @@ sm90_compute_tile_shape_or_override() { if constexpr (cute::is_same_v) { if constexpr (detail::sm90_is_cooperative_v) { + using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{}))); if constexpr (size<0>(TileShape_MNK{}) >= 128) { - return Shape<_128,_32>{}; + return Shape<_128, N_tile>{}; } else { - return Shape<_64,_32>{}; + return Shape<_64, N_tile>{}; } } else if constexpr (detail::sm90_is_warp_specialized_v) { if constexpr (sizeof_bits_v == 8) { - return Shape<_64,_64>{}; + using N_tile = decltype(cute::min(_64{}, get<1>(TileShape_MNK{}))); + return Shape<_64, N_tile>{}; } else { - return Shape<_64,_32>{}; + using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{}))); + return Shape<_64,N_tile>{}; } } else { @@ -265,6 +268,13 @@ struct Sm90TmaBuilderImpl { using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + using CopyOpS2G = + SM90_TMA_STORE + ; + using CopyOpG2S = + SM90_TMA_LOAD + ; + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination using FusionCallbacks = @@ -285,10 +295,10 @@ struct Sm90TmaBuilderImpl { ElementD, GmemStrideTypeD, FusionCallbacks, - SM90_TMA_LOAD, + CopyOpG2S, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_load_op_for_source()), - SM90_TMA_STORE, + CopyOpS2G, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_store_op_for_accumulator()) >; @@ -400,6 +410,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, + class Schedule, FloatRoundStyle RoundStyle > struct CollectiveBuilder< @@ -416,9 +427,11 @@ struct CollectiveBuilder< ElementD, GmemLayoutTagD, AlignmentD, - NoSmemWarpSpecialized, - fusion::LinearCombination, - void> { + Schedule, + fusion::LinearCombination, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v >> { // Passing void C disables source load using ElementC = cute::conditional_t, @@ -433,12 +446,21 @@ struct CollectiveBuilder< ElementD, FragmentSize, ElementAccumulator, ElementCompute, ScaleType, RoundStyle, ElementC>; - using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< - cutlass::epilogue::collective::DefaultEpilogue< - cutlass::detail::TagToStrideC_t, - cutlass::detail::TagToStrideC_t, - ThreadOp, - cutlass::gemm::EpilogueDefault> + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>>, + // Epilogue for Ptr-Array and Grouped Gemm + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogueArray< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + Schedule>> >; }; @@ -533,6 +555,9 @@ struct CollectiveBuilder< FusionOperation, void> { private: + static_assert(cute::is_same_v>, + "Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead."); + // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance) // since TMA epilogues are not compatible with non-TMA non-WS mainloops using EpilogueSchedule = NoSmemWarpSpecialized; @@ -595,7 +620,7 @@ CollectiveBuilder< cute::is_base_of_v >> { private: using FusionOp = - fusion::LinCombEltAct; + fusion::LinCombEltAct; using ImplSchedule = cute::conditional_t, TmaWarpSpecialized, TmaWarpSpecializedCooperative>; @@ -677,7 +702,7 @@ private: GmemStrideTypeAux, typename Schedule::ElementT>()); using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute, - typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute + typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute >; using FusionCallbacksAux = fusion::FusionCallbacks< DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux @@ -685,7 +710,7 @@ private: using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< Schedule::template ActivationFunctor, ElementD, ElementCompute, - typename Schedule::ElementBias, ElementCompute + typename Schedule::ElementBias, ElementC_, ElementCompute >; using FusionCallbacksNoAux = fusion::FusionCallbacks< DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN @@ -750,7 +775,7 @@ struct CollectiveBuilder< GmemLayoutTagD, AlignmentD, cutlass::gemm::EpilogueTransposed, - fusion::LinearCombination, + fusion::LinearCombination, void> { // Passing void C disables source load using ElementC = cute::conditional_t, diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index 02cb795b..cb4e3825 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -62,7 +62,7 @@ template < class GmemLayoutTagD, int AlignmentD, class Schedule, - class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, class Enable = void > struct CollectiveBuilder { diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 36ccdce0..4b10b06e 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -54,6 +54,7 @@ class CollectiveEpilogue { #include "detail.hpp" #include "default_epilogue.hpp" +#include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" #include "sm70_epilogue_vectorized.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index 99286cec..74c7e63f 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -131,8 +131,9 @@ public: return true; } + // Note: SharedStorage is unused for DefaultEpilogue CUTLASS_HOST_DEVICE - DefaultEpilogue(Params const& params_) + DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) : params(params_), epilogue_op(params_.thread) { } CUTLASS_DEVICE diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp new file mode 100644 index 00000000..49daa472 --- /dev/null +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/int.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Applies an element wise operation to all elements within the fragment +// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogueArray { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using StridesC = cute::conditional_t, + StrideC const*, StrideC>; + using StridesD = cute::conditional_t, + StrideD const*, StrideD>; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::is_same_v || + cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StridesC dC{}; + ElementD** ptr_D = nullptr; + StridesD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + DefaultEpilogueArray(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + StrideC stride_c; + StrideD stride_d; + if constexpr (cute::is_same_v) { + stride_c = detail::get_epilogue_stride(params.dC[l_coord]); + stride_d = detail::get_epilogue_stride(params.dD[l_coord]); + } + else { + stride_c = detail::get_epilogue_stride(params.dC); + stride_d = detail::get_epilogue_stride(params.dD); + } + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C[l_coord]), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index e6106343..afb62b05 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -170,7 +170,8 @@ public: [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, [[maybe_unused]] TiledMma tiled_mma, [[maybe_unused]] int thread_idx, - [[maybe_unused]] TensorStorage& shared_tensors) + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] int subtile_idx=-1) { return load_pipe_producer_state; } @@ -202,7 +203,8 @@ public: cute::Tensor accumulators, TiledMma tiled_mma, int thread_idx, - TensorStorage& shared_tensors) + TensorStorage& shared_tensors, + int subtile_index = -1) { constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 8f4c11f8..f8ae4dcf 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -85,6 +85,7 @@ public: using CopyAtomR2G = CopyAtomR2G_; static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -179,7 +180,7 @@ public: // synchronizing function for smem reads/writes #if CUDA_BARRIER_ENABLED - auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0); }; + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; #else auto synchronize = [] () { __syncthreads(); }; #endif diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index dcb6a07a..bd0ed764 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -109,8 +109,8 @@ public: using CopyOpR2S = CopyOpR2S_; using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; - using GmemTiledCopyC = SM90_TMA_LOAD; - using GmemTiledCopyD = SM90_TMA_STORE; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); @@ -198,12 +198,12 @@ public: struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, - make_tensor(static_cast(nullptr), + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideC{}, int32_t(0)), StrideC{}), SmemLayoutC{}(_,_,0))); using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, - make_tensor(static_cast(nullptr), + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideD{}, int32_t(0)), StrideD{}), SmemLayoutD{}(_,_,0))); @@ -225,14 +225,20 @@ public: // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; + auto M_C = + size(M) + ; + auto M_D = + size(M) + ; typename Params::TMA_C tma_load_c; if constexpr (not cute::is_void_v) { - Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC)); tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0)); } - Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); typename Params::TMA_D tma_store_d = make_tma_copy( CopyOpS2G{}, tensor_d, @@ -332,25 +338,31 @@ public: TileCoordMNKL tile_coord_mnkl, TiledMma tiled_mma, int thread_idx, - TensorStorage& shared_tensors) { + TensorStorage& shared_tensors, + int subtile_idx=-1) { using namespace cute; // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto coord_shape = + make_coord(m_coord, n_coord, l_coord) + ; + // Tile residue - auto m_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + auto m_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); })); - auto n_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + auto n_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); })); auto residue_mn = make_coord(m_max_coord, n_max_coord); // Represent the full source tensor, slice to get the tile this CTA is currently responsible for - Tensor mC = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) // Apply epilogue subtile, get matching smem tensor SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); @@ -391,6 +403,9 @@ public: for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { CUTLASS_PRAGMA_UNROLL for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } // Acquire the lock for this stage constexpr uint16_t mcast_mask = 0; uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); @@ -449,31 +464,36 @@ public: cute::Tensor accumulators, TiledMma tiled_mma, int thread_idx, - TensorStorage& shared_tensors) { + TensorStorage& shared_tensors, + int subtile_idx=-1) { using namespace cute; using ElementAccumulator = typename AccEngine::value_type; using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; static_assert(is_rmem::value, "Accumulator must be RF resident."); - static_assert(cute::rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); - static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "TileShapeMNK must be static"); - static_assert(cute::rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); - static_assert(cute::rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{}); - auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{}); + auto mma_tile_m = tile_size<0>(tiled_mma); + auto mma_tile_n = tile_size<1>(tiled_mma); auto epi_tile_m = size<0>(EpilogueTile{}); auto epi_tile_n = size<1>(EpilogueTile{}); + auto coord_shape = + make_coord(m_coord, n_coord, l_coord) + ; // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mD = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) // Apply epilogue subtiling Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) @@ -530,11 +550,11 @@ public: Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) // Coordinate tensors and residue for tile quantization - auto m_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto m_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl); return cute::max(0, c_m); })); - auto n_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto n_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl); return cute::max(0, c_n); })); @@ -559,13 +579,13 @@ public: tRS_cD, tRS_rC }; - auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); }; + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; // Predication for TMA store (one warp issues TMA store) bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; @@ -594,9 +614,12 @@ public: for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } // The current tile in accumulator int mma_m = epi_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); if (is_producer_load_needed) { diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 639115c4..ecc871c0 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -46,6 +46,8 @@ namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// struct NoSmemWarpSpecialized {}; +struct NoSmemWarpSpecializedArray {}; +struct NoSmemWarpSpecializedGroup {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; // DEPRECATED schedules, will be removed in next release diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 82e12404..8867ab9f 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -50,6 +50,8 @@ struct FusionOperation { // metadata types/queries that can be overrided using ElementOutput = void; using ElementCompute = void; + + using ElementSource = void; static constexpr bool IsSourceSupported = false; using ElementScalar = void; @@ -96,11 +98,13 @@ struct ScaledAcc : FusionOperation { template< class ElementOutput_, class ElementCompute_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinearCombination : ScaledAcc { + using ElementSource = ElementSource_; static constexpr bool IsSourceSupported = true; }; @@ -109,11 +113,12 @@ template< template class ActivationFn_, class ElementOutput_, class ElementCompute_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombEltAct - : LinearCombination { + : LinearCombination { using ActivationFn = ActivationFn_; static constexpr bool IsEltActSupported = true; }; @@ -123,12 +128,13 @@ template< class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentBias_ = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBias - : LinearCombination { + : LinearCombination { using ElementBias = ElementBias_; static constexpr int AlignmentBias = AlignmentBias_; static constexpr bool IsPerRowBiasSupported = true; @@ -140,13 +146,14 @@ template< class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentBias_ = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombPerRowBiasEltAct : LinCombPerRowBias { + ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { using ActivationFn = ActivationFn_; static constexpr bool IsEltActSupported = true; }; @@ -160,6 +167,7 @@ template< class ElementCompute_, class ElementAux_ = ElementOutput_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentAux_ = 128 / sizeof_bits_v, int AlignmentBias_ = 128 / sizeof_bits_v, @@ -167,7 +175,7 @@ template< > struct LinCombPerRowBiasEltActAux : LinCombPerRowBiasEltAct { + ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { using ElementAux = ElementAux_; using GmemLayoutTagAux = GmemLayoutTagAux_; static constexpr int AlignmentAux = AlignmentAux_; @@ -180,6 +188,7 @@ template< class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, // per-row alpha/beta int AlignmentBias_ = 128 / sizeof_bits_v, int AlignmentScalar_ = 128 / sizeof_bits_v, @@ -187,7 +196,7 @@ template< > struct PerRowLinCombPerRowBiasEltAct : LinCombPerRowBiasEltAct { + ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { static constexpr int AlignmentScalar = AlignmentScalar_; static constexpr bool IsPerRowScaleSupported = true; }; @@ -202,13 +211,14 @@ template< class ElementOutput_, class ElementCompute_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentBias_ = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct ScaledLinCombPerRowBiasEltAct : LinCombPerRowBiasEltAct { + ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { static constexpr bool IsScaleFactorSupported = true; }; @@ -231,6 +241,7 @@ template< class ElementAux_ = ElementOutput_, class ElementAmax_ = ElementCompute_, class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentAux_ = 128 / sizeof_bits_v, int AlignmentBias_ = 128 / sizeof_bits_v, @@ -238,7 +249,7 @@ template< > struct ScaledLinCombPerRowBiasEltActAmaxAux : ScaledLinCombPerRowBiasEltAct { + ElementBias_, ElementSource_, ElementScalar_, AlignmentBias_, RoundStyle_> { using ElementAmax = ElementAmax_; static constexpr bool IsAbsMaxSupported = true; @@ -257,12 +268,16 @@ template< class ElementOutput_, class ElementCompute_, class ElementAux_ = ElementOutput_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentAux_ = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest > struct LinCombDeEltAct - : LinCombEltAct { + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsDeEltActSupported = true; + using ElementAux = ElementAux_; using GmemLayoutTagAux = GmemLayoutTagAux_; static constexpr int AlignmentAux = AlignmentAux_; @@ -280,6 +295,7 @@ template< class ElementCompute_, class ElementAux_ = ElementOutput_, class ElementBias_ = ElementCompute_, + class ElementSource_ = ElementOutput_, class ElementScalar_ = ElementCompute_, int AlignmentAux_ = 128 / sizeof_bits_v, int AlignmentBias_ = 128 / sizeof_bits_v, @@ -287,7 +303,7 @@ template< > struct LinCombDeEltActDePerRowBias : LinCombDeEltAct { + ElementAux_, ElementSource_, ElementScalar_, AlignmentAux_, RoundStyle_> { using ElementBias = ElementBias_; static constexpr int AlignmentBias = AlignmentBias_; static constexpr bool IsDePerRowBiasSupported = true; diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index da392ce1..a9f9456e 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -113,13 +113,14 @@ struct FusionCallbacks< template< class ElementOutput, class ElementCompute, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinearCombination = Sm90EVT, // beta * C + (alpha * acc) Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc Sm90ScalarBroadcast, // alpha Sm90AccFetch // acc @@ -133,6 +134,7 @@ template < bool ReuseSmemC, class ElementOutput, class ElementCompute, + class ElementSource, class ElementScalar, FloatRoundStyle RoundStyle, class CtaTileShapeMNK, @@ -140,13 +142,13 @@ template < > struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, - fusion::LinearCombination, + fusion::LinearCombination, CtaTileShapeMNK, EpilogueTile -> : Sm90LinearCombination::type, ElementCompute, ElementScalar, RoundStyle> { +> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { - using Impl = Sm90LinearCombination::type, ElementCompute, ElementScalar, RoundStyle>; - using Operation = fusion::LinearCombination; + using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; struct Arguments { ElementScalar alpha = ElementScalar(1); @@ -180,12 +182,13 @@ template< template class ActivationFn, class ElementOutput, class ElementCompute, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombEltAct = Sm90EVT, // activation(beta * C + (alpha * acc)) - Sm90LinearCombination // beta * C + (alpha * acc) + Sm90LinearCombination // beta * C + (alpha * acc) >; template < @@ -196,6 +199,7 @@ template < template class ActivationFn, class ElementOutput, class ElementCompute, + class ElementSource, class ElementScalar, FloatRoundStyle RoundStyle, class CtaTileShapeMNK, @@ -203,13 +207,13 @@ template < > struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombEltAct, + fusion::LinCombEltAct, CtaTileShapeMNK, EpilogueTile -> : Sm90LinCombEltAct { +> : Sm90LinCombEltAct { - using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementScalar, RoundStyle>; - using Operation = fusion::LinCombEltAct; + using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; struct Arguments { ElementScalar alpha = ElementScalar(1); @@ -250,6 +254,7 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest @@ -257,7 +262,7 @@ template< using Sm90LinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc @@ -273,6 +278,7 @@ template < class ElementOutput, class ElementCompute, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentBias, FloatRoundStyle RoundStyle, @@ -281,15 +287,15 @@ template < > struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, - fusion::LinCombPerRowBias, + fusion::LinCombPerRowBias, CtaTileShapeMNK, EpilogueTile > : Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> { + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { using Impl = Sm90LinCombPerRowBias< - CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>; + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; using Operation = fusion::LinCombPerRowBias< - ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>; + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; struct Arguments { ElementScalar alpha = ElementScalar(1); @@ -330,13 +336,14 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombPerRowBiasEltAct = Sm90EVT, - Sm90LinCombPerRowBias + Sm90LinCombPerRowBias >; template < @@ -348,6 +355,7 @@ template < class ElementOutput, class ElementCompute, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentBias, FloatRoundStyle RoundStyle, @@ -357,21 +365,21 @@ template < struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >, CtaTileShapeMNK, EpilogueTile > : Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle > { using Impl = Sm90LinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >; using Operation = fusion::LinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >; struct Arguments { @@ -426,6 +434,7 @@ template< class ElementCompute, class ElementAux = ElementOutput, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentAux = 128 / sizeof_bits_v, int AlignmentBias = 128 / sizeof_bits_v, @@ -434,7 +443,7 @@ template< using Sm90LinCombPerRowBiasEltActAux = Sm90EVT, Sm90EVT, - Sm90LinCombPerRowBias + Sm90LinCombPerRowBias > >; @@ -449,6 +458,7 @@ template < class ElementCompute, class ElementAux, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentAux, int AlignmentBias, @@ -462,7 +472,7 @@ struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::LinCombPerRowBiasEltActAux< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >, CtaTileShapeMNK, EpilogueTile, @@ -470,18 +480,18 @@ struct FusionCallbacks< CopyOpR2S > : Sm90LinCombPerRowBiasEltActAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { using Impl = Sm90LinCombPerRowBiasEltActAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = fusion::LinCombPerRowBiasEltActAux< GmemLayoutTagAux, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; struct Arguments { @@ -535,6 +545,7 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, int AlignmentScalar = 128 / sizeof_bits_v, @@ -543,7 +554,7 @@ template< using Sm90PerRowLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha Sm90AccFetch, // acc @@ -558,6 +569,7 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, int AlignmentScalar = 128 / sizeof_bits_v, @@ -566,7 +578,7 @@ template< using Sm90PerRowLinCombPerRowBiasEltAct = Sm90EVT, Sm90PerRowLinCombPerRowBias + ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle> >; template < @@ -578,6 +590,7 @@ template < class ElementOutput, class ElementCompute, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentBias, int AlignmentScalar, @@ -588,21 +601,21 @@ template < struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle >, CtaTileShapeMNK, EpilogueTile > : Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle > { using Impl = Sm90PerRowLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle >; using Operation = fusion::PerRowLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle >; struct Arguments { @@ -664,6 +677,7 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest @@ -671,7 +685,7 @@ template< using Sm90ScaledLinCombPerRowBias = Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, 2>, // scale_c * beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha Sm90AccFetch, // acc @@ -690,6 +704,7 @@ template< class ElementOutput, class ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentBias = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest @@ -698,7 +713,7 @@ using Sm90ScaledLinCombPerRowBiasEltAct = Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // activation(Z) // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias - Sm90ScaledLinCombPerRowBias + Sm90ScaledLinCombPerRowBias >, Sm90ScalarBroadcast // scale_d >; @@ -712,6 +727,7 @@ template < class ElementOutput, class ElementCompute, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentBias, FloatRoundStyle RoundStyle, @@ -721,21 +737,21 @@ template < struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >, CtaTileShapeMNK, EpilogueTile > : Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle > { using Impl = Sm90ScaledLinCombPerRowBiasEltAct< - CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >; using Operation = fusion::ScaledLinCombPerRowBiasEltAct< - ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle >; struct Arguments { @@ -819,6 +835,7 @@ template< class ElementAux = ElementOutput, class ElementAmax = ElementCompute, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentAux = 128 / sizeof_bits_v, int AlignmentBias = 128 / sizeof_bits_v, @@ -827,7 +844,7 @@ template< using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = Sm90SplitTreeVisitor< // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + Sm90ScaledLinCombPerRowBias, // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d Sm90EVT, // amax_d @@ -860,6 +877,7 @@ template < class ElementAux, class ElementAmax, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentAux, int AlignmentBias, @@ -873,7 +891,7 @@ struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::ScaledLinCombPerRowBiasEltActAmaxAux< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >, CtaTileShapeMNK, EpilogueTile, @@ -882,19 +900,19 @@ struct FusionCallbacks< > : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { using Impl = Sm90ScaledLinCombPerRowBiasEltActAmaxAux< CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = fusion::ScaledLinCombPerRowBiasEltActAmaxAux< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; struct Arguments { @@ -1014,13 +1032,14 @@ template< class ElementOutput, class ElementCompute, class ElementAux = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentAux = 128 / sizeof_bits_v, FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombDeEltAct = Sm90EVT, // activation(beta * C + (alpha * acc), aux) - Sm90LinearCombination, // beta * C + (alpha * acc) + Sm90LinearCombination, // beta * C + (alpha * acc) Sm90AuxLoad // aux >; @@ -1034,6 +1053,7 @@ template < class ElementOutput, class ElementCompute, class ElementAux, + class ElementSource, class ElementScalar, int AlignmentAux, FloatRoundStyle RoundStyle, @@ -1046,7 +1066,7 @@ struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::LinCombDeEltAct< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementScalar, AlignmentAux, RoundStyle + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle >, CtaTileShapeMNK, EpilogueTile, @@ -1054,18 +1074,18 @@ struct FusionCallbacks< CopyOpS2R > : Sm90LinCombDeEltAct< CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle > { using Impl = Sm90LinCombDeEltAct< CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle >; using Operation = fusion::LinCombDeEltAct< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementScalar, AlignmentAux, RoundStyle + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle >; struct Arguments { @@ -1118,6 +1138,7 @@ template< class ElementCompute, class ElementAux = ElementOutput, class ElementBias = ElementOutput, + class ElementSource = ElementOutput, class ElementScalar = ElementCompute, int AlignmentAux = 128 / sizeof_bits_v, int AlignmentBias = 128 / sizeof_bits_v, @@ -1128,7 +1149,7 @@ using Sm90LinCombDeEltActDePerRowBias = Sm90EVT, AlignmentBias>, Sm90LinCombDeEltAct + ElementCompute, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle> > >; @@ -1143,6 +1164,7 @@ template < class ElementCompute, class ElementAux, class ElementBias, + class ElementSource, class ElementScalar, int AlignmentAux, int AlignmentBias, @@ -1156,7 +1178,7 @@ struct FusionCallbacks< epilogue::Sm90TmaWarpSpecialized, fusion::LinCombDeEltActDePerRowBias< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >, CtaTileShapeMNK, EpilogueTile, @@ -1164,18 +1186,18 @@ struct FusionCallbacks< CopyOpS2R > : Sm90LinCombDeEltActDePerRowBias< CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle > { using Impl = Sm90LinCombDeEltActDePerRowBias< CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, - ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; using Operation = fusion::LinCombDeEltActDePerRowBias< GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, - ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle >; struct Arguments { diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e5332fc1..6887098a 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -215,16 +215,17 @@ template < class StrideScalar, int ScalarCount, template class ScalarReduceFn, + class ElementSource, class InputAddOp // Z > struct Sm90TreeVisitor< Sm90Compute, Sm90ScalarBroadcast, - Sm90SrcFetch, + Sm90SrcFetch, InputAddOp > : Sm90VisitorImpl< Sm90ScalarBroadcast, - Sm90SrcFetch, + Sm90SrcFetch, InputAddOp, Sm90Compute > @@ -232,11 +233,10 @@ struct Sm90TreeVisitor< using Impl = Sm90VisitorImpl< Sm90ScalarBroadcast, - Sm90SrcFetch, + Sm90SrcFetch, InputAddOp, Sm90Compute >; - using Params = typename Impl::Params; using SharedStorage = typename Impl::SharedStorage; @@ -260,8 +260,9 @@ struct Sm90TreeVisitor< CUTLASS_DEVICE bool is_C_load_needed() const { auto const& bcast_op = get<0>(Impl::ops); + auto const& src_op = get<1>(Impl::ops); auto const& added_op = get<2>(Impl::ops); - return bcast_op.scalar != 0 || added_op.is_C_load_needed(); + return (bcast_op.scalar != 0 && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); } template @@ -320,17 +321,9 @@ struct Sm90TreeVisitor< // ReLU with aux bit tensor dReLU/dZ // Aux(i) = Z(i) >= 0 ? 1 : 0 namespace detail { -template < - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - class StrideMNL, - int Alignment, - bool EnableNullptr -> -struct Sm90ReLUAuxStore { - static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); - +// Placeholder node so we can retain standard EVT structure +template +struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { struct SharedStorage {}; struct Arguments { @@ -362,41 +355,90 @@ struct Sm90ReLUAuxStore { Sm90ReLUAuxStore() { } CUTLASS_HOST_DEVICE - Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) - : params(params) { } + Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } +}; +} // namespace detail - Params const params; +// Specialization on the generic compute+aux EVT +template < + // Compute node + template class Activation, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + // Aux node + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment, + bool EnableNullptr, + // Input node + class InputOp +> +struct Sm90TreeVisitor< + Sm90Compute, cutlass::epilogue::thread::ReLu> || + is_same_v, cutlass::epilogue::thread::Clamp> >>, + Sm90TreeVisitor< + Sm90AuxStore< + Stages, + EpilogueTile, + cutlass::uint1b_t, + RoundStyle, + StrideMNL, + SmemLayoutAtom, + CopyOpR2S, + Alignment, + EnableNullptr + >, + InputOp + > +> : Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; - CUTLASS_DEVICE bool - is_producer_load_needed() const { - return false; - } + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} - CUTLASS_DEVICE bool - is_C_load_needed() const { - return false; - } + CUTLASS_HOST_DEVICE + Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) + : params(params_), Impl(params_, shared_storage) {} - template - CUTLASS_DEVICE auto - get_producer_load_callbacks(ProducerLoadArgs const& args) { - return EmptyProducerLoadCallbacks{}; - } + Params const& params; - template - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + struct ConsumerStoreCallbacks : CallbacksImpl { CUTLASS_DEVICE ConsumerStoreCallbacks( RTensor&& tC_rAux, GTensor&& tC_gAux, CTensor tC_cAux, ResidueMN residue_mn, - Params const& params) + Params const& params, + CallbacksImpl&& impl) : tC_rAux(cute::forward(tC_rAux)), tC_gAux(cute::forward(tC_gAux)), tC_cAux(tC_cAux), residue_mn(residue_mn), - params(params) {} + params(params), + CallbacksImpl(cute::forward(impl)) {} RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) @@ -404,13 +446,23 @@ struct Sm90ReLUAuxStore { ResidueMN residue_mn; Params const& params; - template - CUTLASS_DEVICE auto - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const& frg_input) { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); + + // Compute activation + aux + using ElementInput = typename decltype(frg_input)::Element; using ConvertInput = NumericArrayConverter; using ConvertAux = PackPredicates; - using ComputeOutput = cutlass::epilogue::thread::ReLu; + using ComputeOutput = Activation; using ConvertOutput = NumericArrayConverter; ConvertInput convert_input{}; ComputeOutput relu{}; @@ -422,7 +474,12 @@ struct Sm90ReLUAuxStore { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { ElementCompute pre_relu = frg_compute[i]; - frg_compute[i] = relu(frg_compute[i]); + if constexpr (is_same_v, cutlass::epilogue::thread::Clamp>) { + frg_compute[i] = relu(frg_compute[i], params_compute); + } + else { + frg_compute[i] = relu(frg_compute[i]); + } frg_aux[i] = frg_compute[i] == pre_relu; } @@ -435,8 +492,18 @@ struct Sm90ReLUAuxStore { CUTLASS_DEVICE void end() { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + callbacks_input.end(); + + // Nullptr is no-op if constexpr (EnableNullptr) { - if (params.ptr_aux == nullptr) { + if (params_aux.ptr_aux == nullptr) { return; } } @@ -473,114 +540,25 @@ struct Sm90ReLUAuxStore { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // Unpack params + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); - Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params_aux.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) gAux, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); + auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params, cute::move(callbacks_impl)); } }; -} // namespace detail - -// Specialization on the generic compute+aux EVT -template < - // Compute node - template class Activation, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle, - // Aux node - int Stages, - class EpilogueTile, - class StrideMNL, - class SmemLayoutAtom, - class CopyOpR2S, - int Alignment, - bool EnableNullptr, - // Input node - class InputOp -> -struct Sm90TreeVisitor< - Sm90Compute, cutlass::epilogue::thread::ReLu>, void>>, - Sm90TreeVisitor< - Sm90AuxStore< - Stages, - EpilogueTile, - cutlass::uint1b_t, - RoundStyle, - StrideMNL, - SmemLayoutAtom, - CopyOpR2S, - Alignment, - EnableNullptr - >, - InputOp - > -> : Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - > -{ - using Impl = - Sm90VisitorImpl< - Sm90VisitorImpl< - InputOp, - detail::Sm90ReLUAuxStore - >, - Sm90Compute - >; - - using Params = typename Impl::Params; - using SharedStorage = typename Impl::SharedStorage; - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor() {} - - CUTLASS_HOST_DEVICE - Sm90TreeVisitor( - Params const& params, - SharedStorage const& shared_storage) - : Impl(params, shared_storage) {} - - template - struct ConsumerStoreCallbacks : CallbacksImpl { - CUTLASS_DEVICE - ConsumerStoreCallbacks(CallbacksImpl&& impl) - : CallbacksImpl(cute::forward(impl)) { } - - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { - auto& [callbacks_input, callbacks_relu_aux] = get<0>(CallbacksImpl::callbacks_tuple).callbacks_tuple; - - Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); - return callbacks_relu_aux.visit(frg_acc, epi_v, epi_m, epi_n, frg_input); - } - }; - - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... Args - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); - return ConsumerStoreCallbacks(std::move(callbacks_tuple)); - } - -}; // Aux load for uint1b_t template < diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index df3b988b..c37f1dae 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -85,16 +85,17 @@ using Sm90SplitTreeFetch = Sm90AccFetch; ///////////////////////////////////////////////////////////////////////////////////////////////// // returns C +template struct Sm90SrcFetch : Sm90VisitorImpl<> { CUTLASS_DEVICE bool is_producer_load_needed() const { - return true; + return is_C_load_needed(); } CUTLASS_DEVICE bool is_C_load_needed() const { - return true; + return not is_void_v; } using Sm90VisitorImpl<>::Sm90VisitorImpl; @@ -105,7 +106,6 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { ConsumerStoreCallbacks(SrcTensor const& tCrC) : tCrC(tCrC) {} - // make this a pointer if we need default ctor for generic tuple of visitors SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) template @@ -122,7 +122,7 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - + // register type may differ from logical type so we can't assert matching types here return ConsumerStoreCallbacks(args.tCrC); } }; @@ -344,7 +344,6 @@ struct Sm90AuxLoad { make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) - return ConsumerStoreCallbacks( cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 13780f3e..9dc9fc8e 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -303,7 +303,7 @@ private: static constexpr bool IsAtomic = is_atomic>::value; static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); -public: +public: struct SharedStorage { }; struct Arguments { @@ -814,7 +814,7 @@ public: } auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); @@ -843,8 +843,8 @@ public: return; } - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; auto [m, n, k, l] = tile_coord_mnkl; constexpr bool ReferenceSrc = decltype(ref_src)::value; @@ -1002,15 +1002,15 @@ public: return; } - auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; using ReduceOutput = GmemReduceFn; using ConvertOutput = NumericConverter; ReduceOutput reduce_output{}; ConvertOutput convert_output{}; - + // Reduction over batches if (size<2>(stride(gCol_l)) == 0) { CUTLASS_PRAGMA_NO_UNROLL @@ -1051,8 +1051,8 @@ public: CUTLASS_DEVICE bool is_reduction_buffer_needed(int epi_m, int epi_n, bool is_last_iteration) const { - auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; return (not IsAtomic && // atomic reduction doesn't use smem @@ -1111,7 +1111,7 @@ public: auto args_tuple = make_tuple( bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, - lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks(std::move(args_tuple), params); } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 30ac4761..6b625df9 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -563,7 +563,6 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { template get_consumer_store_callbacks(args); return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -614,7 +613,6 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl(args); return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -692,7 +690,6 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl { template get_consumer_store_callbacks(args); return ConsumerStoreCallbacks(std::move(callbacks_tuple)); } - }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h index 0549d753..2712a4af 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_relu.h @@ -49,7 +49,6 @@ namespace cutlass { namespace epilogue { namespace thread { - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -66,13 +65,65 @@ struct ArrayMaximum { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ElementsPerAccess; ++i) { - result[i] = fmax(lhs[i], rhs[i]); + result[i] = platform::max(lhs[i].get(), rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &lhs, + Element rhs) const { + + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = platform::max(lhs[i].get(), rhs); } return result; } }; + +/// Partial specialization: Element=float +template +struct ArrayMaximum { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &lhs, + Array const &rhs) const { + + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = fmax(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &lhs, + float rhs) const { + + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = fmax(lhs[i], rhs); + } + + return result; + } +}; + +/// Partial specialization: Element=half template struct ArrayMaximum { @@ -96,6 +147,8 @@ struct ArrayMaximum { res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); } + static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length."); + #else __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); __half const *rhs_ptr = reinterpret_cast<__half const *>(rhs.raw_data()); @@ -133,6 +186,8 @@ struct ArrayMaximum { res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); } + static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length."); + #else __half const *lhs_ptr = reinterpret_cast<__half const *>(lhs.raw_data()); @@ -150,6 +205,90 @@ struct ArrayMaximum { } }; +/// Partial specialization: Element=bfloat16_t +template +struct ArrayMaximum { + + using NvType = __nv_bfloat16; + using NvTypeV2 = __nv_bfloat162; + + CUTLASS_DEVICE + Array operator()( + Array const &lhs, + Array const &rhs) const { + + Array result; + + #if __CUDA_ARCH__ >= 800 + int const kVectorCount = ElementsPerAccess / 2; + + + NvTypeV2 const *lhs_ptr = reinterpret_cast(lhs.raw_data()); + NvTypeV2 const *rhs_ptr = reinterpret_cast(rhs.raw_data()); + NvTypeV2 *res_ptr = reinterpret_cast(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + #else + NvType const *lhs_ptr = reinterpret_cast(lhs.raw_data()); + NvType const *rhs_ptr = reinterpret_cast(rhs.raw_data()); + NvType *res_ptr = reinterpret_cast(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + res_ptr[i] = ((lhs_ptr[i] < rhs_ptr[i]) ? rhs_ptr[i] : lhs_ptr[i]); + } + + #endif + + return result; + } + + CUTLASS_DEVICE + Array operator()( + Array const &lhs, + bfloat16_t rhs) const { + + Array result; + + #if __CUDA_ARCH__ >= 800 + int const kVectorCount = ElementsPerAccess / 2; + + + NvType rhs_raw = reinterpret_cast(rhs); + NvTypeV2 rhs_pair = __bfloat162bfloat162(rhs_raw); + + NvTypeV2 const *lhs_ptr = reinterpret_cast(lhs.raw_data()); + NvTypeV2 *res_ptr = reinterpret_cast(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); + } + + static_assert(!(ElementsPerAccess % 2), "Output array must be divisible by vector length."); + + #else + + NvType const *lhs_ptr = reinterpret_cast(lhs.raw_data()); + NvType const rhs_raw = reinterpret_cast(rhs); + NvType *res_ptr = reinterpret_cast(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + res_ptr[i] = ((lhs_ptr[i] < rhs_raw) ? rhs_raw : lhs_ptr[i]); + } + + #endif + + return result; + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// template @@ -187,6 +326,25 @@ struct ReluConditional { } }; +template +struct ReluConditional { + + CUTLASS_DEVICE + void operator()( + bool conditional[], + Array const &fragment, + bfloat16_t threshold) const { + + __nv_bfloat16 y = reinterpret_cast<__nv_bfloat16 const &>(threshold); + __nv_bfloat16 const *x = reinterpret_cast<__nv_bfloat16 const *>(fragment.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + conditional[i] = !__hlt(x[i], y); + } + } +}; + } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_dgelu.h b/include/cutlass/epilogue/thread/linear_combination_dgelu.h index a8254629..b9931a7c 100644 --- a/include/cutlass/epilogue/thread/linear_combination_dgelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_dgelu.h @@ -92,9 +92,9 @@ public: ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor - ElementCompute threshold; ///< minimum value that is output ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementCompute threshold; ///< minimum value that is output // // Methods // diff --git a/include/cutlass/epilogue/thread/linear_combination_drelu.h b/include/cutlass/epilogue/thread/linear_combination_drelu.h index 44522d2d..85416180 100644 --- a/include/cutlass/epilogue/thread/linear_combination_drelu.h +++ b/include/cutlass/epilogue/thread/linear_combination_drelu.h @@ -87,9 +87,9 @@ public: ElementCompute alpha; ///< scales accumulators ElementCompute beta; ///< scales source tensor - ElementCompute threshold; ///< minimum value that is output ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementCompute threshold; ///< minimum value that is output // // Methods // diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 6b7351cf..78e27197 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -1698,13 +1698,13 @@ private: // if (OutputOp::kStoreZ) { - destination_iterator += reduce_fragment_idx; destination_iterator.store(frag_Z); + ++destination_iterator; } if (OutputOp::kStoreT) { - tensor_iterator += reduce_fragment_idx; tensor_iterator.store(frag_T); + ++tensor_iterator; } } }; diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 7f1c242d..048dff3e 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -187,7 +187,7 @@ CUTLASS_CONSTEXPR_IF_CXX17 value_t lcm(value_t a, value_t b) { value_t temp = gcd(a, b); - return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : 0; + return temp ? (cutlass::abs_for_integer(a) / temp * cutlass::abs_for_integer(b)) : value_t(); } /** @@ -207,8 +207,11 @@ template CUTLASS_HOST_DEVICE constexpr value_t lcm_cxx11(value_t a, value_t b) { - return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) * cutlass::abs_for_integer(b)) : 0; + return gcd_cxx11(a, b) ? (cutlass::abs_for_integer(a) / gcd_cxx11(a, b) * + cutlass::abs_for_integer(b)) + : value_t(); } + /// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b CUTLASS_HOST_DEVICE CUTLASS_CONSTEXPR_IF_CXX17 diff --git a/include/cutlass/gemm/collective/builders/sm90_common.inl b/include/cutlass/gemm/collective/builders/sm90_common.inl index b9c76a64..004eb228 100644 --- a/include/cutlass/gemm/collective/builders/sm90_common.inl +++ b/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -31,6 +31,7 @@ #pragma once #include "cutlass/detail/layout.hpp" +#include "cutlass/detail/collective.hpp" #include "cute/atom/mma_traits_sm90_gmma.hpp" #include "cute/atom/copy_traits_sm90_tma.hpp" @@ -322,13 +323,21 @@ is_input_fp8() { (cute::is_same_v || cute::is_same_v)); } -template +// We need to handle the tuples in this function since it is used in SFINAE dispatch in the CollectiveBuilder. +// At that point, it is not guaranteed that the tuples have been split out into the required parts. +template constexpr bool is_use_rmem_A() { + + using ElementA = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementA>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, MaybeTupleElementB>; + + constexpr bool IsABDifferentWidth = cute::sizeof_bits_v != cute::sizeof_bits_v; + constexpr bool HasScales = cute::is_tuple::value ^ cute::is_tuple::value; constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B(); - constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk; + constexpr bool IsUseRmemA = (!IsInputSizeTwoBytes && !IsLayoutAkBk) || IsABDifferentWidth || HasScales; return IsUseRmemA; } diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 8fa85d8d..5216dc1d 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -79,6 +79,50 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_cou return (CapacityBytes - carveout_bytes) / stage_bytes; } +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(StageCount stage_count) { + return stages; +} + +template +constexpr int get_bits_for_possibly_void_element() { + if constexpr (cute::is_same_v) { + return 0; + } + else { + return sizeof_bits::value; + } +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_single_affine_transformed_input(StageCountAutoCarveout stage_count) { + + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int scale_zero_k_tile = 1; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int s_bits = get_bits_for_possibly_void_element(); + constexpr int z_bits = get_bits_for_possibly_void_element(); + + constexpr int scale_bytes = (s_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8; + constexpr int zero_bytes = (z_bits * size<0>(TileShapeMNK{}) * scale_zero_k_tile) / 8; + static_assert(scale_bytes % 128 == 0, "Scale bytes must be a multiple of 128"); + static_assert(zero_bytes % 128 == 0, "Zero bytes must be a multiple of 128"); + + // When scales are void, s_bits will be 0 so no smem will be allocated for scales. + constexpr int stage_bytes = + (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + scale_bytes + zero_bytes + stage_barrier_bytes; + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + template constexpr bool is_swapAB(){ @@ -105,29 +149,6 @@ is_warpspecialized_transpose_B(){ return IsWarpSpecializedTransposeB; } -template -struct Sm90TypeWidths { - static constexpr bool IsElementALarger = (cute::sizeof_bits_v) > cute::sizeof_bits_v; - using WideType = cute::conditional_t; - using NarrowType = cute::conditional_t; -}; - - -template -constexpr bool -sm90_is_narrow_type_k_major() { - using Widths = Sm90TypeWidths; - using NarrowType = typename Widths::NarrowType; - using WideType = typename Widths::WideType; - - constexpr bool IsANarrow = cute::is_same_v; - constexpr cute::GMMA::Major NarrowGmmaMajor = IsANarrow ? detail::gmma_rs_tag_to_major_A() : - detail::gmma_rs_tag_to_major_B(); - - constexpr bool IsNarrowLayoutKMajor = NarrowGmmaMajor == cute::GMMA::Major::K; - return IsNarrowLayoutKMajor; -} - } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -163,18 +184,24 @@ struct CollectiveBuilder< cute::enable_if_t< (cute::is_same_v || cute::is_same_v || - cute::is_same_v) && + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && not detail::is_use_rmem_A()> > { static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); -static constexpr bool IsFP8Input = detail::is_input_fp8(); + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); // For fp32 types, map to tf32 MMA value type using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; @@ -183,7 +210,8 @@ static constexpr bool IsFP8Input = detail::is_input_fp8(); static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - using AtomLayoutMNK = cute::conditional_t, + using AtomLayoutMNK = cute::conditional_t< + cute::is_same_v || IsArrayOfPointersGemm, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< @@ -199,10 +227,12 @@ static constexpr bool IsFP8Input = detail::is_input_fp8(); static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - /* For FP8 use a separate mainloop compared to other datatypes */ - using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>; + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; @@ -267,7 +297,7 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); @@ -323,13 +353,13 @@ struct CollectiveBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// -// GMMA_TMA_WS_RS Mixed GEMM +// GMMA_TMA_WS_RS Mixed Scaled GEMM template < class ElementPairA_, - class GmemLayoutPairA_, + class GmemLayoutA_, int AlignmentA, class ElementPairB_, - class GmemLayoutPairB_, + class GmemLayoutB_, int AlignmentB, class ElementAccumulator, class TileShape_MNK, @@ -341,10 +371,10 @@ struct CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, ElementPairA_, - GmemLayoutPairA_, + GmemLayoutA_, AlignmentA, ElementPairB_, - GmemLayoutPairB_, + GmemLayoutB_, AlignmentB, ElementAccumulator, TileShape_MNK, @@ -357,22 +387,38 @@ struct CollectiveBuilder< cute::is_same_v)> > { +private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>; + static constexpr bool NeitherIsTuple = !cute::is_tuple::value && !cute::is_tuple::value; + public: - static constexpr bool IsATransformed = cute::sizeof_bits_v < cute::sizeof_bits_v; + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value || + (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + "Either A OR B must be a tuple or the widths of A and B must be different."); - // Split out items for processessing, no splitting for now since scales aren't supported. - using ElementA = ElementPairA_; - using ElementB = ElementPairB_; + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; - using GmemLayoutA = GmemLayoutPairA_; - using GmemLayoutB = GmemLayoutPairB_; + using GmemLayoutA = GmemLayoutA_; + using GmemLayoutB = GmemLayoutB_; + + using ElementPairA = cute::conditional_t, ElementPairA_>; + using ElementPairB = cute::conditional_t, ElementPairB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; static_assert(is_static::value); static_assert(is_static::value); static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); @@ -382,12 +428,7 @@ public: // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. static constexpr bool SwapAB = !IsATransformed; - static_assert(detail::sm90_is_narrow_type_k_major(), "The narrow type must be K-major."); - static_assert((IsATransformed && (cute::sizeof_bits_v <= 8) && (sizeof(ElementB) == 2)) || - (!IsATransformed && (cute::sizeof_bits_v <= 8) && (sizeof(ElementA) == 2)) || - (GmmaMajorA == cute::GMMA::Major::K && GmmaMajorB == cute::GMMA::Major::K), - "The unscaled element must be 2 bytes OR both inputs must be K-major"); // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB; @@ -400,6 +441,7 @@ public: using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); using SmemLayoutAtomB = decltype(detail::rs_smem_selector; using RealElementB = cute::conditional_t; - static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + static constexpr int PipelineStages = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); using SmemCopyAtomA = cute::conditional_t>; using SmemCopyAtomB = cute::conditional_t, void>; @@ -416,35 +458,24 @@ public: using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; // We pack the scale data with the operand that will be optionally scaled and converted before MMA. - using StrideAPair = TagToStrideA_t; - using StrideBPair = TagToStrideB_t; + using StrideA = TagToStrideA_t; + using StrideB = TagToStrideB_t; - using GmemTiledCopyAPair = GmemTiledCopyA; - using SmemLayoutAtomAPair = SmemLayoutAtomA; - using SmemCopyAtomAPair = SmemCopyAtomA; - - using GmemTiledCopyBPair = GmemTiledCopyB; - using SmemLayoutAtomBPair = SmemLayoutAtomB; - using SmemCopyAtomBPair = SmemCopyAtomB; - - - // If the src type of the converter is the same as ElementA, - // interpret this as if the user wanted to apply the scale to the A matrix. using CollectiveOp = CollectiveMma< DispatchPolicy, TileShape_MNK, - ElementPairA_, - StrideAPair, - ElementPairB_, - StrideBPair, + ElementPairA, + StrideA, + ElementPairB, + StrideB, TiledMma, - GmemTiledCopyAPair, - SmemLayoutAtomAPair, - SmemCopyAtomAPair, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, cute::identity, - GmemTiledCopyBPair, - SmemLayoutAtomBPair, - SmemCopyAtomBPair, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, cute::identity >; @@ -483,7 +514,9 @@ struct CollectiveBuilder< cute::enable_if_t< cute::is_same_v || cute::is_same_v || - cute::is_same_v> + cute::is_same_v || + cute::is_same_v || + cute::is_same_v> > { static_assert(is_static::value); static_assert(is_static::value); @@ -495,13 +528,16 @@ struct CollectiveBuilder< static_assert(!detail::is_use_rmem_A(), "Not supported for fp8 non-TN warp specialized kernels yet\n"); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - using AtomLayoutMNK = cute::conditional_t, + static constexpr bool IsArrayOfPointersGemm = (cute::is_same_v || + cute::is_same_v); + using AtomLayoutMNK = cute::conditional_t || + IsArrayOfPointersGemm, Layout>, Layout>>; using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< @@ -517,8 +553,9 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized< - PipelineStages, ClusterShape_MNK, KernelScheduleType>; + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; @@ -580,7 +617,7 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif // For fp32 types, map to tf32 MMA value type @@ -721,7 +758,7 @@ struct CollectiveBuilder< static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif // For fp32 types, map to tf32 MMA value type @@ -819,7 +856,7 @@ struct CollectiveBuilder< static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif // For fp32 types, map to tf32 MMA value type @@ -918,13 +955,20 @@ struct CollectiveBuilder< static_assert(is_static::value); static_assert(is_static::value); #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif -static constexpr bool IsTmaCompatible = detail::is_aligned< - ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(); +using ExtractedElementA = detail::deduce_mixed_width_dtype_t<0, ElementA>; +using ExtractedElementB = detail::deduce_mixed_width_dtype_t<0, ElementB>; -static constexpr bool IsMixedWidthInput = cute::sizeof_bits_v != cute::sizeof_bits_v; +static constexpr bool IsTmaCompatible = detail::is_aligned< + ExtractedElementA, AlignmentA, ExtractedElementB, AlignmentB, detail::tma_alignment_bytes>(); + +// Users opt into scales via the builder by passing a tuple of Elements for the input that will be scaled. We detect +// scale support if ONLY one of the inputs have tuples to describe them. +static constexpr bool OnlyOneIsTuple = cute::is_tuple::value ^ cute::is_tuple::value; +static constexpr bool IsDifferentWidth = sizeof_bits::value != sizeof_bits::value; +static constexpr bool IsMixedWidthInput = IsDifferentWidth || (IsDifferentWidth && OnlyOneIsTuple); #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) // Persistent schedules perform best for CUDA Toolkits with version >= 12.1 diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 985e0ecc..bf26de71 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -56,7 +56,7 @@ template < class TransformB > struct CollectiveMma { - static_assert(cutlass::detail::dependent_false == 0, "Could not find a mainloop specialization."); + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -73,5 +73,6 @@ struct CollectiveMma { #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 00000000..daf07c4c --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,741 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v; + using StridesA = cute::conditional_t; + using StridesB = cute::conditional_t; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StridesA dA; + ElementB const** ptr_B; + StridesB dB; + }; + + // Device side kernel params + struct Params { + TMA_A tma_load_a; + TMA_B tma_load_b; + void* tensormaps; + InternalElementA const** ptr_A; + StridesA dA; + InternalElementB const** ptr_B; + StridesB dB; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(0), 1); + auto [M,N,K,L] = problem_shape_MNKL; + const uint32_t mock_L = 1; + + // These tensor pointers are only used to create tensormap/tma desc. + // This address to the tensor will be replaced with correct address before the initial tma load + InternalElementA const* ptr_A_first_batch = reinterpret_cast(args.ptr_A); + InternalElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); + cudaError_t cuda_error = cudaGetLastError(); // clear previous error + + StrideA stride_a; + StrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless + stride_a = StrideA{}; + stride_b = StrideB{}; + } + else { + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), stride_b)); + TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + bool implementable = true; + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8; + + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(gA_tensormap)); + copy(recast(pB_tensormap), recast(gB_tensormap)); + } + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) const { + Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(gA_tensormap), recast(sA_tensormap)); + copy(recast(gB_tensormap), recast(sB_tensormap)); + + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Only consider dimensions and strides that we need to recalculate and replace for each group + constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // excluding either M or N + static_assert(TensorRank == Int<3>{}, + "Descriptor modification for global dims & strides expects rank as 3."); + cute::array prob_shape_A = {1,1,1}; + cute::array prob_stride_A = {0,0,0}; + cute::array prob_shape_B = {1,1,1}; + cute::array prob_stride_B = {0,0,0}; + + InternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + InternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormap.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Bringing tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormap, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (ie its aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index e9867351..dc0a5e9d 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -134,9 +134,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - using MainloopPipeline = cutlass::PipelineTmaAsync< - DispatchPolicy::Stages, - typename DispatchPolicy::ClusterShape>; + using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; @@ -336,22 +334,20 @@ struct CollectiveMma< /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { + static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } /// Set up the data needed by this collective for load and mma. /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// that the tuple must contain at least two elements, with the first two elements being: + /// Returned tuple must contain at least two elements, with the first two elements being: /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) /// The rest of the tensors can be specified as needed by this collective. - template + template CUTLASS_DEVICE auto - tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; @@ -362,8 +358,8 @@ struct CollectiveMma< Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl); } @@ -379,7 +375,7 @@ struct CollectiveMma< Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& tiled_tensors, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, @@ -405,16 +401,16 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 4d50793e..2b8e92a2 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -49,6 +49,8 @@ #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" #include "cutlass/trace.h" +#include "cutlass/detail/collective.hpp" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -63,33 +65,33 @@ template < class KernelSchedule, class TileShape_, class ElementAOptionalTuple, - class StrideAOptionalTuple, + class StrideA_, class ElementBOptionalTuple, - class StrideBOptionalTuple, + class StrideB_, class TiledMma_, - class GmemTiledCopyAOptionalTuple, - class SmemLayoutAtomAOptionalTuple, - class SmemCopyAtomAOptionalTuple, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, class TransformA_, - class GmemTiledCopyBOptionalTuple, - class SmemLayoutAtomBOptionalTuple, - class SmemCopyAtomBOptionalTuple, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput, TileShape_, ElementAOptionalTuple, - StrideAOptionalTuple, + StrideA_, ElementBOptionalTuple, - StrideBOptionalTuple, + StrideB_, TiledMma_, - GmemTiledCopyAOptionalTuple, - SmemLayoutAtomAOptionalTuple, - SmemCopyAtomAOptionalTuple, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, TransformA_, - GmemTiledCopyBOptionalTuple, - SmemLayoutAtomBOptionalTuple, - SmemCopyAtomBOptionalTuple, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> { private: @@ -104,6 +106,17 @@ private: } } + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + public: // // Type Aliases @@ -111,30 +124,53 @@ public: using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; using TileShape = TileShape_; - using ElementA = ElementAOptionalTuple; - using ElementB = ElementBOptionalTuple; - static constexpr bool IsATransformed = cute::sizeof_bits_v < cute::sizeof_bits_v; - using ElementScale = void; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale], [ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || + (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); - using StrideA = StrideAOptionalTuple; - using StrideB = StrideBOptionalTuple; - using StrideScale = void; - static constexpr int AlignmentScale = cute::Int<0>{}; using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyAOptionalTuple; - using GmemTiledCopyB = GmemTiledCopyBOptionalTuple; - using GmemTiledCopyScale = void; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; - using SmemLayoutAtomA = SmemLayoutAtomAOptionalTuple; - using SmemLayoutAtomB = SmemLayoutAtomBOptionalTuple; - using SmemLayoutAtomScale = void; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + // Scale layout atom set after swapping. - using SmemCopyAtomA = SmemCopyAtomAOptionalTuple; - using SmemCopyAtomB = SmemCopyAtomBOptionalTuple; - using SmemCopyAtomScale = void; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) static constexpr bool IsLayoutAkBmn = @@ -167,22 +203,20 @@ public: using InternalTransformA = cute::conditional_t; using InternalTransformB = cute::conditional_t; - static_assert(sizeof(InternalElementB) == 2 || - (cute::is_same_v, layout::RowMajor> && - cute::is_same_v, layout::ColumnMajor>), - "B operand after swap must be 2 bytes OR K-major."); static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; using TmaElementA = cute::conditional_t; using ArchTag = typename DispatchPolicy::ArchTag; using MainloopPipeline = cutlass::PipelineTmaAsync< - DispatchPolicy::Stages, - typename DispatchPolicy::ClusterShape>; + DispatchPolicy::Stages>; using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; + using SmemLayoutAtomScale = Layout(InternalSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -191,8 +225,10 @@ public: static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert(cute::is_same_v || cute::is_same_v, - "The TMA mcast for A must match the mcast for scales or the scale tiled copy must be void."); + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( InternalSmemLayoutAtomA{}, @@ -202,6 +238,12 @@ public: InternalSmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape( + SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,NonVoidStrideScale>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). static constexpr bool IsLayoutAmnBmn = @@ -230,23 +272,114 @@ public: make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // These two restrictions are related, so we place the assertions together. + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); static_assert(TransposeB xor (cute::is_same_v), "Should be same layout if not TransposeB."); static_assert(!TransposeB || size<1>(SmemLayoutB{}) * cute::sizeof_bits_v / 8 == 128, "SmemLayoutB K must be 128bytes to be transposed."); - + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + static constexpr auto + elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto + elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale ) { + return 0; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes() { + constexpr uint32_t a_bytes = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v) / 8); + constexpr uint32_t b_bytes = (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v) / 8); + + constexpr uint32_t baseline_bytes = a_bytes + b_bytes; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v) / 8); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = (size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v) / 8); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + +public: static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); struct SharedStorage { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); struct TensorStorage : cute::aligned_struct { - cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_A; cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; @@ -261,6 +394,10 @@ public: StrideA dA{}; ElementB const* ptr_B = nullptr; StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; uint32_t mma_promotion_interval = 4; }; @@ -268,13 +405,13 @@ public: struct Params { private: using Outer = CollectiveMma; + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, + TransformB_>; public: // Assumption: StrideA is congruent with Problem_MK @@ -284,6 +421,21 @@ public: SmemLayoutA{}(_,_,cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy( + GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, @@ -293,6 +445,10 @@ public: size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any TMA_A tma_load_a; TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; }; // @@ -339,16 +495,50 @@ public: SmemLayoutA{}(_,_,cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_,_,cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return { - tma_load_a, - tma_load_b - }; + + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0 }; + } + else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + ElementScale const* ptr_S = args.ptr_S; + StrideScale dS = args.dS; + Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M,scale_k,L), dS)); + tma_load_scale = make_tma_copy( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size }; + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M,scale_k,L), dS)); + tma_load_zero = make_tma_copy( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + _1{}); // mcast along N mode for this M load, if any + return { tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size }; + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } } template @@ -366,6 +556,35 @@ public: constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + const int scale_mn = SwapAB ? N : M; + const int scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(scale_mn,scale_k,L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); } @@ -373,27 +592,39 @@ public: } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr uint32_t TmaTransactionBytes = - (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v) / 8)+ - (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v) / 8); + static constexpr uint32_t TmaTransactionBytes = compute_tma_transaction_bytes(); /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } /// Set up the data needed by this collective for load and mma. /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// that the tuple must contain at least two elements, with the first two elements being: + /// Returned tuple must contain at least two elements, with the first two elements being: /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) /// The rest of the tensors can be specified as needed by this collective. - template + template CUTLASS_DEVICE auto - tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; @@ -404,16 +635,38 @@ public: Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - return cute::make_tuple(gA_mkl, gB_nkl); + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } + else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M,scale_k,L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_,_)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } } /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective + /// This overload gets triggered when we have scales. template < - class TensorA, class TensorB, + class... Ts, class KTileIterator, class BlockCoord > CUTLASS_DEVICE void @@ -421,7 +674,7 @@ public: Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& tiled_tensors, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, @@ -429,43 +682,58 @@ public: TensorStorage& shared_tensors) { using namespace cute; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs"); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs"); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); if (warp_idx_in_warp_group == 0 and lane_predicate) { - Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) - Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) // - // Prepare the TMA loads for A and B + // Prepare the TMA loads for A, B and Scales // constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) uint16_t mcast_mask_a = 0; uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; // Issue TmaLoads // Maps the tile -> block, value @@ -483,6 +751,8 @@ public: } } + auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + // Mainloop CUTLASS_PRAGMA_NO_UNROLL for ( ; k_tile_count > 0; --k_tile_count) { @@ -499,6 +769,38 @@ public: int write_stage = smem_pipe_write.index(); copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } + else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes + // on the fly. + // We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K + // is a multiple of the threadblock tile K + const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + ++k_tile_iter; // Advance smem_pipe_write @@ -575,7 +877,7 @@ public: // Allocate fragments and descriptors Tensor tCrA_mma = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrA_load = make_fragment_like(tCrA_mma); + Tensor tCrA_load = make_fragment_like(tCrA_mma); Tensor tCsB = thread_mma.partition_B(gmma_sB_position_dependent); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) @@ -583,11 +885,21 @@ public: // // Copy Atom A retiling // - auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Compute the max vector length that can be used to copy A. This will match the vector width of the + // conversions used. It helps by allowing the compiler to convert using the same register that was used + // to load the data from smem. This significantly reduces the need to move data among registers. + // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform + // the conversion does not impact correctness. + using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = partition_extra_mma_info(thread_mma, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K @@ -626,16 +938,19 @@ public: barrier_token = pipeline.consumer_try_wait(smem_pipe_read); // copy smem->rmem for A operand - copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); - transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, read_stage); + + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); // transpose B operand in SMEM transpose(sB, gmma_sB, read_stage, 0); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA_load) - 1; ++k_block) { - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); - transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); transpose.synchronize(k_block); transpose(sB, gmma_sB, read_stage, k_block + 1); warpgroup_arrive(); @@ -650,8 +965,9 @@ public: --k_tile_count; if (k_tile_count > 0) { pipeline.consumer_wait(smem_pipe_read, barrier_token); - copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); - transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); transpose(sB, gmma_sB, smem_pipe_read.index(), 0); } warpgroup_arrive(); @@ -688,14 +1004,16 @@ public: } if (k_block == size<2>(tCrA_load) - 1) { pipeline.consumer_wait(smem_pipe_read, barrier_token); - copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); - transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, 0); // transpose B operand in SMEM transpose(sB, gmma_sB, smem_pipe_read.index(), 0); } else { - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); - transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); // transpose B operand in SMEM transpose.synchronize(k_block); // make transpose of k_block available transpose(sB, gmma_sB, read_stage, k_block + 1); @@ -732,8 +1050,9 @@ public: CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA_load) - 1; ++k_block) { - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); - transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA, tCrA_copy_view, + partitioned_extra_info, copy_partitions_extra_info, k_block + 1, read_stage); + transform_A_kblock(tCrA_load, A_CPY_VEC{}, tCrA_mma, partitioned_extra_info, k_block + 1); transpose.synchronize(k_block); // make k_block transpose available transpose(sB, gmma_sB, read_stage, k_block + 1); warpgroup_arrive(); @@ -779,48 +1098,314 @@ public: } private: - template - static constexpr bool - is_fast_converter_exact() { - using DstType = typename Converter::result_type; - using SrcType = typename Converter::source_type; - - constexpr bool IsIntToFP32Exact = cute::is_same_v && - (cute::numeric_limits::is_integer && cute::sizeof_bits_v <= 16); - - constexpr bool IsIntToFP16orBF16Exact = (cute::is_same_v || cute::is_same_v) && - (cute::numeric_limits::is_integer && cute::sizeof_bits_v <= 8); + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE + auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { - return IsIntToFP32Exact || IsIntToFP16orBF16Exact; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::tuple{}; + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } } - template > - CUTLASS_DEVICE void - transform_internal_A(Tensor&& in, Tensor&& out) { + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. + template + CUTLASS_DEVICE + auto partition_extra_mma_info( + ThreadMma const& thread_mma, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // noting to do + return cute::tuple{}; + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = thread_mma.partition_A(sS); + Tensor tCrS = make_fragment_like(thread_mma.partition_fragment_A(sS(_,_,Int<0>{}))); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = thread_mma.partition_A(sZ); + Tensor tCrZ = make_fragment_like(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{}))); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. + template + CUTLASS_DEVICE + auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // noting to do + return cute::tuple{}; + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE + void copy_A_and_extra_info( + SmemTiledCopyA const& smem_tiled_copy_A, + TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, + int k_block, + int read_stage) { + + copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block)); + + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_,_,k_block,read_stage), tCrS_copy_view(_,_,k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_,_,k_block,read_stage), tCrZ_copy_view(_,_,k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + /// Utilities to transform A. + template + CUTLASS_DEVICE + void transform_A_kblock( + TCrA_load const& tCrA_load, + cute::Int vec_A, + TCrA_mma& tCrA_mma, + cute::tuple const& partitioned_extra_info, + int const k_block) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrZ = cute::get<3>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), + vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), + tCrS(_, _, 0), + tCrZ(_, _, 0), + make_fragment_like(tCrZ)(_, _, 0), + tCrA_mma(_, _, k_block)); + } + else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + + /// Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void + convert_tensor( + Tensor const& in, + Tensor& out, + cute::Int width = {}) { + /// This is an element-wise conversion where we expect both tensors to have the same layout. /// As a result, we can cast as a cutlass array to use the fast numeric converters without - /// worrying about indexing into the layout. + /// worrying about indexing into the layout. + constexpr int N = cosize_v; - /// The inputs must be backed by registers & be statically sized so we can unroll the conversion loops. + /// The inputs must be backed by registers & be statically sized. static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); - static_assert(cute::is_same_v, "Input engine must be same type as the A operand"); - static_assert(cute::is_same_v, "Output engine must be same type as the Mma input"); static_assert(is_static_v, "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); + static_assert(N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); - using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - using DefaultConverterAToB = cutlass::NumericArrayConverter; - using FastConverterAToB = cutlass::FastNumericArrayConverter; + using Converter = cutlass::NumericArrayConverter; - using ConverterAToB = cute::conditional_t(), FastConverterAToB, DefaultConverterAToB>; + constexpr int NumIterations = N / ConversionVectorWidth; - SrcArray* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())); - DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())); - *dst_array_ptr = std::move(ConverterAToB::convert(*src_array_ptr)); + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } } + + template + CUTLASS_DEVICE void + transform_internal_A( + Tensor&& in, + cute::Int a_vec_width, + Tensor&& out) { + + convert_tensor(in, out, a_vec_width); + } + + template + CUTLASS_DEVICE void + transform_internal_A( + Tensor&& in, + cute::Int a_vec_width, + Tensor&& converted_inputs, + Tensor&& scales, + Tensor&& out) { + + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } + + template + CUTLASS_DEVICE void + transform_internal_A( + Tensor&& in, + cute::Int a_vec_width, + Tensor&& converted_inputs, + Tensor&& scales, + Tensor&& zeros, + Tensor&& converted_zeros, + Tensor&& out) { + + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + static_assert(cute::is_same_v, + "Type of the engine zero buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + convert_tensor(zeros, converted_zeros); + + // Apply scales and broadcast across inputs, store in converted_inputs + cute::transform(converted_inputs, scales, converted_inputs, cute::multiplies{}); + cute::transform(converted_inputs, converted_zeros, converted_inputs, cute::plus{}); + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 38e6ca1d..3b0336bf 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -106,9 +106,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - using MainloopPipeline = cutlass::PipelineTmaAsync< - DispatchPolicy::Stages, - typename DispatchPolicy::ClusterShape>; + using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename cutlass::PipelineState; @@ -147,8 +145,7 @@ struct CollectiveMma< using InternalElementA = cute::conditional_t>>; using InternalElementB = cute::conditional_t>>; - struct SharedStorage - { + struct SharedStorage { cute::array_aligned> smem_A; cute::array_aligned> smem_B; @@ -244,13 +241,13 @@ struct CollectiveMma< /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { + static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective template < class TensorA, class TMA_LOAD_A, class TensorB, class TMA_LOAD_B, @@ -281,8 +278,8 @@ struct CollectiveMma< "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); SharedStorage& storage = *reinterpret_cast(shared_memory); - Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(storage.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(storage.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // // Prepare the TMA loads for A and B @@ -295,11 +292,11 @@ struct CollectiveMma< auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) // // Prepare TMA membars and PREFETCH @@ -335,9 +332,7 @@ struct CollectiveMma< params.is_leader = warp_group_thread_idx == 0; params.num_consumers = NumThreadsPerWarpGroup; - MainloopPipeline pipeline( - storage.pipeline_storage, - params); + MainloopPipeline pipeline(storage.pipeline_storage, params, ClusterShape{}); // State variables used for iterating the circular buffer // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 1dffea46..90552862 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -107,9 +107,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - using MainloopPipeline = cutlass::PipelineTmaAsync< - DispatchPolicy::Stages, - typename DispatchPolicy::ClusterShape>; + using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; @@ -255,22 +253,20 @@ struct CollectiveMma< /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { + static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } /// Set up the data needed by this collective for load and mma. /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// that the tuple must contain at least two elements, with the first two elements being: + /// Returned tuple must contain at least two elements, with the first two elements being: /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) /// The rest of the tensors can be specified as needed by this collective. - template + template CUTLASS_DEVICE auto - tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; @@ -281,8 +277,8 @@ struct CollectiveMma< Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl); } @@ -298,14 +294,12 @@ struct CollectiveMma< Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& tiled_tensors, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - + TensorStorage& shared_tensors) { using namespace cute; int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; @@ -322,16 +316,16 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); // Partition the inputs based on the current block coordinates. auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -386,10 +380,7 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster CUTLASS_DEVICE void - load_tail( - MainloopPipeline pipeline, - PipelineState smem_pipe_write) - { + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); @@ -418,8 +409,7 @@ struct CollectiveMma< int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { + Params const& mainloop_params) { using namespace cute; static_assert(is_rmem::value, "C tensor must be rmem resident."); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 09f97f8e..301cb1e0 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -108,9 +108,7 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - using MainloopPipeline = cutlass::PipelineTmaAsync< - DispatchPolicy::Stages, - typename DispatchPolicy::ClusterShape>; + using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; @@ -261,14 +259,12 @@ struct CollectiveMma< /// Set up the data needed by this collective for load and mma. /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// that the tuple must contain at least two elements, with the first two elements being: + /// Returned tuple must contain at least two elements, with the first two elements being: /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - /// The rest of the tensors can be specified as needed by this collective. - template + template CUTLASS_DEVICE auto - tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { using X = Underscore; // Separate out problem shape for convenience auto [M,N,K,L] = problem_shape_MNKL; @@ -279,8 +275,8 @@ struct CollectiveMma< Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) return cute::make_tuple(gA_mkl, gB_nkl); } @@ -296,7 +292,7 @@ struct CollectiveMma< Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& tiled_tensors, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, @@ -320,8 +316,8 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 33a2958c..94e7fa88 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -42,6 +42,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/detail/layout.hpp" #include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" @@ -106,9 +107,12 @@ public: using LayoutC = gemm::detail::StrideToLayoutTagC_t; using LayoutD = gemm::detail::StrideToLayoutTagC_t; - // NOTE: 3.0 kernels do not support complex transforms for now ... - static ComplexTransform const kTransformA = ComplexTransform::kNone; - static ComplexTransform const kTransformB = ComplexTransform::kNone; + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 using MathOperator = cutlass::arch::OpMultiplyAdd; @@ -141,17 +145,17 @@ public: static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; // Warp shape is not a primary API type in 3.x - // But we can best approximate it by inspecting the TiledMma::TiledShape_MNK + // But we can best approximate it by inspecting the TiledMma // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads - static constexpr int WarpsInMma = cute::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); static constexpr int WarpsInMmaM = 4; static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); using WarpCount = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape< - cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, - cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, - cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; @@ -270,7 +274,12 @@ public: /// Initializes GEMM state from arguments. Status - initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); @@ -283,20 +292,33 @@ public: // Initialize the Params structure params_ = GemmKernel::to_underlying_arguments(args, workspace); - // account for dynamic smem capacity if needed - int smem_size = GemmKernel::SharedStorageSize; - if (smem_size >= (48 << 10)) { - CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); - cudaError_t result = cudaFuncSetAttribute( - device_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); - return Status::kErrorInternal; + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } } } + return Status::kSuccess; } @@ -317,7 +339,7 @@ public: /// Primary run() entry point API that is static allowing users to create and manage their own params. /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() static Status - run(Params& params, cudaStream_t stream = nullptr) { + run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); @@ -331,13 +353,54 @@ public: dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; - launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + launch_result = cuda_adapter->launch( + grid, cluster, block, smem_size, stream, kernel_params, 0 + ); + } + else { + return Status::kErrorInternal; + } + } + else { + + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + + } } else { launch_result = Status::kSuccess; - device_kernel<<>>(params); + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + device_kernel<<>>(params); + } } cudaError_t result = cudaGetLastError(); @@ -356,18 +419,27 @@ public: /// Launches the kernel after first constructing Params internal state from supplied arguments. Status - run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { Status status = initialize(args, workspace, stream); if (Status::kSuccess == status) { - status = run(params_, stream); + status = run(params_, stream, cuda_adapter); } return status; } /// Launches the kernel after first constructing Params internal state from supplied arguments. Status - operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - return run(args, workspace, stream); + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(args, workspace, stream, cuda_adapter); } /// Overload that allows a user to re-launch the same kernel without updating internal params struct. @@ -387,7 +459,7 @@ public: ////////////////////////////// CUTLASS 2.x API ///////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -template +template class GemmUniversalAdapter< GemmKernel_, cute::enable_if_t::value>> @@ -501,9 +573,14 @@ public: } /// Initializes GEMM state from arguments. - Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr + ) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream, cuda_adapter); } /// Lightweight update given a subset of arguments. @@ -513,13 +590,18 @@ public: } /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { + Status run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { - return underlying_operator_.run(stream); + return underlying_operator_.run(stream, cuda_adapter); } /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) { + Status operator()( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(stream); } @@ -527,12 +609,13 @@ public: Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) { + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { - status = run(stream); + status = run(stream, cuda_adapter); } return status; diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 265eedfd..0a8f75ee 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -46,6 +46,7 @@ #include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/device_kernel.h" +#include "cutlass/cuda_host_adapter.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/kernel/gemm_universal.h" @@ -69,6 +70,8 @@ class GemmUniversalBase { public: using GemmKernel = GemmKernel_; + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + using ThreadblockShape = typename GemmKernel::Mma::Shape; using ElementA = typename GemmKernel::ElementA; @@ -295,7 +298,8 @@ public: Status initialize( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); @@ -323,9 +327,8 @@ public: return Status::kSuccess; } - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); @@ -339,13 +342,27 @@ public: "block: (" << block << "), " "SMEM: (" << smem_size_ << ")"); - Kernel2<<>>(params_); + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms_}; + return cuda_adapter->launch(grid, block, smem_size_, stream, kernel_params, 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); - // Query for errors - cudaError_t result = cudaGetLastError(); - if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; + Kernel2<<>>(params_); + + // Query for errors + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } } return Status::kSuccess; @@ -363,12 +380,13 @@ public: Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { - status = run(stream); + status = run(stream, cuda_adapter); } return status; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index ea526115..9c871e4b 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -53,6 +53,8 @@ struct KernelTma { }; struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; +struct KernelArrayTmaWarpSpecializedCooperative { }; +struct KernelGroupTmaWarpSpecializedCooperative { }; ////////////////////////////////////////////////////////////////////////////// @@ -65,6 +67,8 @@ struct KernelTmaWarpSpecializedCooperative { }; struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; +struct KernelArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelArrayTmaWarpSpecializedCooperative { }; +struct KernelGroupTmaWarpSpecializedCooperativeFP8FastAccum : KernelGroupTmaWarpSpecializedCooperative { }; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -225,6 +229,23 @@ struct MainloopSm90TmaGmmaWarpSpecializedFP8 "KernelSchedule must be one of the warp specialized policies"); }; +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelGroupTmaWarpSpecializedCooperative +> +struct MainloopSm90ArrayTmaGmmaWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + static_assert( + cute::is_base_of_v || + cute::is_base_of_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/gemm_enumerated_types.h b/include/cutlass/gemm/gemm_enumerated_types.h index 25e12404..a03e0c5d 100644 --- a/include/cutlass/gemm/gemm_enumerated_types.h +++ b/include/cutlass/gemm/gemm_enumerated_types.h @@ -69,6 +69,7 @@ enum class GemmUniversalMode { kGemmSplitKParallel, kBatched, kArray, + kGrouped, kInvalid }; diff --git a/include/cutlass/gemm/group_array_problem_shape.hpp b/include/cutlass/gemm/group_array_problem_shape.hpp new file mode 100644 index 00000000..853ee449 --- /dev/null +++ b/include/cutlass/gemm/group_array_problem_shape.hpp @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing problem shapes + for 3.x Ptr-Array GEMMs and Grouped GEMMs. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_coord.h" + +#include "cute/container/array.hpp" + +#if ! defined(__CUDACC_RTC__) +#include +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GroupProblemShape { + using UnderlyingProblemShape = ProblemShape_; + int32_t num_groups = 1; + UnderlyingProblemShape* problem_shapes = nullptr; + UnderlyingProblemShape const* host_problem_shapes = nullptr; + + CUTLASS_HOST_DEVICE + int32_t groups() const { return num_groups; } + + CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_problem_shape(int32_t group_idx) const { + return problem_shapes[group_idx]; + } + + CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_host_problem_shape(int32_t group_idx) const { + return host_problem_shapes[group_idx]; + } +}; + +template +class ArrayProblemShape { +public: + using UnderlyingProblemShape = ProblemShape_; + + ArrayProblemShape() = default; + ArrayProblemShape(UnderlyingProblemShape ps) : problem_shape_(ps) {} + + // Num of groups for Ptr-Array GEMM always remain one, just the number of batches (l) can vary + // This is just to maintain uniformity with GroupProblemShape + constexpr int32_t groups() const { return 1; } + + UnderlyingProblemShape* problem_shapes() const { + return &problem_shape_; + } + UnderlyingProblemShape const* host_problem_shapes() const { + return &problem_shape_; + } + + // This is just to maintain uniformity with GroupProblemShape + CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_problem_shape(int32_t /* unused */ = 0) const { + return problem_shape_; + } + + CUTLASS_HOST_DEVICE + UnderlyingProblemShape const + get_host_problem_shape(int32_t /* unused */ = 0) const { + return problem_shape_; + } +private: + UnderlyingProblemShape problem_shape_{}; +}; + +} // namespace cutlass::gemm diff --git a/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h index 36f47c66..73ab1e9a 100644 --- a/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h @@ -851,7 +851,9 @@ protected: } ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; if (ptr_Tensor) { - ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor; + ptr_Tensor = ReferenceFactory::add_pointer_offset( + ptr_Tensor, + tile_work.tiled_coord.k() * params.batch_stride_Tensor); } if (ptr_Vector) { ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector; @@ -2024,7 +2026,9 @@ protected: ptr_C += tile_work.tiled_coord.k() * params.batch_stride_C; ptr_D += tile_work.tiled_coord.k() * params.batch_stride_D; if (ptr_Tensor) { - ptr_Tensor += tile_work.tiled_coord.k() * params.batch_stride_Tensor; + ptr_Tensor = ReferenceFactory::add_pointer_offset( + ptr_Tensor, + tile_work.tiled_coord.k() * params.batch_stride_Tensor); } if (ptr_Vector) { ptr_Vector += tile_work.tiled_coord.k() * params.batch_stride_Vector; diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index f095bc53..c11ab114 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -67,9 +67,9 @@ class GemmUniversal< Epilogue_, ThreadblockSwizzle_, void, - // 3.x kernels use the first template argument to define the ProblemShape tuple + // 3.x kernels use the first template argument to define the ProblemShape // We use this invariant to SFINAE dispatch against either the 2.x API or the 3.x API - cute::enable_if_t::value> + cute::enable_if_t::value || IsCutlass3ArrayKernel::value)> > { public: diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 7bd98ce7..c6fc49b2 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -61,6 +61,19 @@ template < > class GemmUniversal; + +//////////////////////////////////////////////////////////////////////////////// + +// In cases where ProblemShape is not a tuple, this is used to check if the +// underlying problem shape type is aliased within or not. +// Used for dispatching GemmUniversal to 2.x API or 3.x API +template +struct IsCutlass3ArrayKernel : cute::false_type { }; + +template +struct IsCutlass3ArrayKernel> + : cute::true_type { }; + //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm::kernel @@ -75,4 +88,5 @@ class GemmUniversal; #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index 1c58b44e..e24d9acf 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -42,7 +42,7 @@ #include "cutlass/complex.h" #include "cutlass/semaphore.h" #include "cutlass/gemm/kernel/params_universal_base.h" - +#include "cutlass/subbyte_reference.h" #include "cutlass/trace.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -676,7 +676,9 @@ public: } ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; if (ptr_Tensor) { - ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + ptr_Tensor = ReferenceFactory::add_pointer_offset( + ptr_Tensor, + threadblock_tile_offset.k() * params.batch_stride_Tensor); } if (ptr_Vector) { ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; @@ -1387,7 +1389,9 @@ public: ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; if (ptr_Tensor) { - ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + ptr_Tensor = ReferenceFactory::add_pointer_offset( + ptr_Tensor, + threadblock_tile_offset.k() * params.batch_stride_Tensor); } if (ptr_Vector) { ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 55955d43..886f2884 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -217,6 +217,8 @@ public: // Only used by device-level operator GemmCoord *host_problem_sizes; + bool allow_early_exit; + // // Methods // @@ -235,7 +237,8 @@ public: ldb(nullptr), ldc(nullptr), ldd(nullptr), - host_problem_sizes(nullptr) + host_problem_sizes(nullptr), + allow_early_exit(false) { } @@ -256,7 +259,8 @@ public: typename LayoutB::Stride::LongIndex *ldb, typename LayoutC::Stride::LongIndex *ldc, typename LayoutC::Stride::LongIndex *ldd, - GemmCoord *host_problem_sizes=nullptr + GemmCoord *host_problem_sizes=nullptr, + bool allow_early_exit=false ): mode(mode), problem_sizes(problem_sizes), @@ -271,7 +275,8 @@ public: ldb(ldb), ldc(ldc), ldd(ldd), - host_problem_sizes(host_problem_sizes) + host_problem_sizes(host_problem_sizes), + allow_early_exit(allow_early_exit) { } @@ -303,6 +308,7 @@ public: typename LayoutC::Stride::LongIndex *ldc; typename LayoutC::Stride::LongIndex *ldd; + bool allow_early_exit; // // Methods @@ -318,7 +324,8 @@ public: lda(nullptr), ldb(nullptr), ldc(nullptr), - ldd(nullptr) + ldd(nullptr), + allow_early_exit(false) { } CUTLASS_HOST_DEVICE @@ -333,7 +340,8 @@ public: lda(args.lda), ldb(args.ldb), ldc(args.ldc), - ldd(args.ldd) + ldd(args.ldd), + allow_early_exit(args.allow_early_exit) { } @@ -388,6 +396,12 @@ public: CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Early exit following LAPACK's definition + if (params.allow_early_exit && + (params.output_op.alpha == ElementC(0)) && (params.output_op.beta == ElementC(1))) { + return; + } + // // Problem visitor. // diff --git a/include/cutlass/gemm/kernel/rank_2k_universal.h b/include/cutlass/gemm/kernel/rank_2k_universal.h index 2775710d..62a88d39 100644 --- a/include/cutlass/gemm/kernel/rank_2k_universal.h +++ b/include/cutlass/gemm/kernel/rank_2k_universal.h @@ -140,6 +140,8 @@ public: typename LayoutC::Stride::Index ldc; typename LayoutC::Stride::Index ldd; + bool allow_early_exit; + // // Methods // @@ -147,7 +149,8 @@ public: Arguments(): mode(GemmUniversalMode::kGemm), batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr), + allow_early_exit(false) { } /// constructs an arguments structure Arguments( @@ -166,7 +169,8 @@ public: typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd + typename LayoutC::Stride::Index ldd, + bool allow_early_exit = false ): mode(mode), problem_size(problem_size), @@ -174,7 +178,8 @@ public: epilogue(epilogue), ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), + allow_early_exit(allow_early_exit) { } @@ -231,6 +236,8 @@ public: int *semaphore; + bool allow_early_exit; + // // Methods // @@ -255,7 +262,8 @@ public: batch_stride_B(0), batch_stride_C(0), batch_stride_D(0), - semaphore(nullptr) { } + semaphore(nullptr), + allow_early_exit(false) { } CUTLASS_HOST_DEVICE Params( @@ -285,7 +293,8 @@ public: batch_stride_B(args.batch_stride_B), batch_stride_C(args.batch_stride_C), batch_stride_D(args.batch_stride_D), - semaphore(static_cast(workspace)) { + semaphore(static_cast(workspace)), + allow_early_exit(args.allow_early_exit) { } CUTLASS_HOST_DEVICE @@ -347,6 +356,12 @@ public: CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Early exit following LAPACK's definition + if (params.allow_early_exit && + (params.output_op.alpha == ElementC(0)) && (params.output_op.beta == ElementC(1))) { + return; + } + // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; diff --git a/include/cutlass/gemm/kernel/rank_k_universal.h b/include/cutlass/gemm/kernel/rank_k_universal.h index 188a4e70..7d2fb081 100644 --- a/include/cutlass/gemm/kernel/rank_k_universal.h +++ b/include/cutlass/gemm/kernel/rank_k_universal.h @@ -125,6 +125,8 @@ public: typename LayoutC::Stride::Index ldc; typename LayoutC::Stride::Index ldd; + bool allow_early_exit; + // // Methods // @@ -132,7 +134,8 @@ public: Arguments(): mode(GemmUniversalMode::kGemm), batch_count(1), - ptr_A(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + ptr_A(nullptr), ptr_C(nullptr), ptr_D(nullptr), + allow_early_exit(false) { } /// constructs an arguments structure Arguments( @@ -148,7 +151,8 @@ public: int64_t batch_stride_D, typename LayoutA::Stride::Index lda, typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd + typename LayoutC::Stride::Index ldd, + bool allow_early_exit = false ): mode(mode), problem_size(problem_size), @@ -156,7 +160,8 @@ public: epilogue(epilogue), ptr_A(ptr_A), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), + allow_early_exit(allow_early_exit) { } @@ -196,6 +201,8 @@ public: int *semaphore; + bool allow_early_exit; + // // Methods // @@ -218,7 +225,8 @@ public: batch_stride_B(0), batch_stride_C(0), batch_stride_D(0), - semaphore(nullptr) { } + semaphore(nullptr), + allow_early_exit(false) { } CUTLASS_HOST_DEVICE Params( @@ -246,7 +254,8 @@ public: batch_stride_B(args.batch_stride_A), batch_stride_C(args.batch_stride_C), batch_stride_D(args.batch_stride_D), - semaphore(static_cast(workspace)) { + semaphore(static_cast(workspace)), + allow_early_exit(args.allow_early_exit) { } CUTLASS_HOST_DEVICE @@ -313,6 +322,12 @@ public: cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Early exit following LAPACK's definition + if (params.allow_early_exit && + (params.output_op.alpha == ElementC(0)) && (params.output_op.beta == ElementC(1))) { + return; + } + // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index e22fb436..e5fe6ec5 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -60,7 +60,7 @@ public: // using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); // Mainloop derived types @@ -101,7 +101,7 @@ public: sizeof(typename CollectiveMainloop::SharedStorage), sizeof(typename CollectiveEpilogue::SharedStorage))); - static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(cute::size(TiledMma{})); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side arguments @@ -141,8 +141,9 @@ public: static bool can_implement(Arguments const& args) { - return args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + bool mode_implementable = args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + return mode_implementable && TileScheduler::can_implement(args.scheduler); } static int diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp new file mode 100644 index 00000000..28ac4a0e --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,749 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t || + cute::is_base_of_v> +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using Schedule = typename DispatchPolicy::Schedule; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + static_assert(cute::is_void_v, + "Ptr-Array Cooperative and Grouped Gemm Cooperative kernel only supports the default scheduler."); + + static constexpr bool IsGroupedGemmKernel = cute::is_base_of_v; + + using TileScheduler = cute::conditional_t::Scheduler, + typename detail::TileSchedulerSelector< + void, ArchTag, TileShape, ClusterShape>::Scheduler>; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + + struct TensorMapStorage : cute::aligned_struct<128> { + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + KernelHardwareInfo hw_info; + TileSchedulerParams scheduler; + void* workspace; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + ProblemShape problem_shapes = args.problem_shape; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + hw_info, + scheduler, + workspace + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (cute::is_base_of_v) { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } else if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } + else { + implementable = false; + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN ? TileScheduler::RasterOrderOptions::AlongN : TileScheduler::RasterOrderOptions::AlongM; + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + else { + grid_shape = TileScheduler::get_grid_shape(params.problem_shape.get_host_problem_shape(), TileShape{}, ClusterShape{}, params.hw_info, args); + } + return grid_shape; + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Note: Tma Descriptor Prefetch (from either const or param) is not applicable here + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + // Purpose of maintaining this pipeline state is to make sure TMA loads have finished before doing descriptor updates + typename CollectiveMainloop::PipelineState mainloop_pipe_tma_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + return [] () { cute::cluster_wait(); }; + } + else { + __syncthreads(); + return [] () {}; // do nothing + } + } (); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx; + int32_t next_batch = curr_batch; + int32_t const mock_l_coord = 0; + int32_t const sm_idx = blockIdx.x + (blockIdx.y * gridDim.x); + int32_t const sm_count = params.hw_info.sm_count; + + // Fetch a copy of tensormaps for the CTA + auto input_tensormaps = collective_mainloop.tensormaps_init(params.mainloop, sm_count, sm_idx); + + // Update tensormap for the initial batch for the CTA + if (work_tile_info.is_valid()) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + next_batch + ); + // Ensure warp is converged before issuing tensor replace + __syncwarp(); + // Entire warp must do this (ie its aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, mock_l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + input_tensormaps, + blk_coord, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + next_batch = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); // Usually just returns work_tile_info.L_idx + + if (work_tile_info.is_valid() && next_batch != curr_batch ) { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(next_batch), Int<1>{}); + } + // Wait for the last TMA stage to complete loading, before issuing tensormap updates + mainloop_pipe_tma_consumer_state.advance(work_k_tile_count-1); + mainloop_pipeline.consumer_wait(mainloop_pipe_tma_consumer_state); + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape_MNKL, + next_batch + ); + // Ensure warp is converged before issuing tensor replace + __syncwarp(); + // Entire warp must do this (ie its aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + curr_batch = next_batch; + // Advance the TMA consumer state for the last remaining stage that was being waited for above + mainloop_pipe_tma_consumer_state.advance(1); + } + else if (work_tile_info.is_valid()) { // case where batch/group didn't change between tiles + // Advance the TMA consumer state for all the stages to be in sync + mainloop_pipe_tma_consumer_state.advance(work_k_tile_count); + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + while (work_tile_info.is_valid()) { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() + ); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(params.problem_shape.get_problem_shape(work_tile_info.L_idx), Int<1>{}); + } + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + } + } // Consumer Warp Groups End + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo + fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, + TileScheduler& scheduler) const { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 14cb47e1..67f23afa 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -121,7 +121,7 @@ public: sizeof(typename CollectiveMainloop::SharedStorage), sizeof(typename CollectiveEpilogue::SharedStorage))); - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side arguments @@ -176,6 +176,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 2ec1aa0e..582beee9 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -128,7 +128,7 @@ public: static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 1; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side arguments @@ -183,6 +183,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } @@ -270,7 +272,7 @@ public: mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; @@ -335,14 +337,14 @@ public: CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); // Compute m_coord, n_coord, and l_coord with their post-tiled shapes auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); @@ -363,7 +365,7 @@ public: params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - tiled_tensors, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, @@ -378,8 +380,7 @@ public: if (collective_epilogue.is_producer_load_needed()) { // Ensure warp is converged before issuing epilogue loads __syncwarp(); - epi_load_pipe_producer_state = - collective_epilogue.load( + epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 551bb231..25d7711c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -105,8 +105,8 @@ public: using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; /// Register requirement for Load and Math WGs @@ -203,6 +203,12 @@ public: workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); return { args.mode, @@ -210,7 +216,7 @@ public: CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace), + scheduler, workspace }; } @@ -226,14 +232,17 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); return implementable; } static size_t get_workspace_size(Arguments const& args) { size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); @@ -247,11 +256,12 @@ public: Status status = Status::kSuccess; uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); status = TileScheduler::template initialize_workspace( - args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); if (status != Status::kSuccess) { return status; @@ -353,7 +363,7 @@ public: mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = size(TiledMma{}); mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; @@ -392,7 +402,7 @@ public: PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - auto cluster_wait_fn = [&] () { + auto cluster_wait_fn = [] () { // We need this to guarantee that the Pipeline init is visible // To all producers and consumer thread blocks in the Cluster if constexpr (size(ClusterShape{}) > 1) { @@ -420,14 +430,14 @@ public: CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); // Get pipeline stage increments from tensor shapes auto k_tile_count = size<3>(gA_mkl); @@ -442,6 +452,11 @@ public: if (producer_warp_role == ProducerWarpRole::Mainloop) { bool do_load_order_arrive = true; while (work_tile_info.is_valid()) { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -457,7 +472,7 @@ public: params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - tiled_tensors, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, @@ -483,8 +498,10 @@ public: // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { - load_order_barrier.wait(); while (work_tile_info.is_valid()) { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) { + load_order_barrier.wait(); + } if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); @@ -501,7 +518,8 @@ public: blk_coord, tiled_mma, lane_idx, - shared_storage.tensors.epilogue + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() ); } @@ -531,27 +549,27 @@ public: // // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if(TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) { + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); - collective_mainloop.mma( - mainloop_pipeline, - mainloop_pipe_consumer_state, - accumulators, - work_k_tile_count, - mma_thread_idx, - shared_storage.tensors.mainloop, - params.mainloop - ); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail( - mainloop_pipeline, - mainloop_pipe_consumer_state, - work_k_tile_count - ); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } // Index of warp group within consumer warp groups int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; @@ -573,7 +591,8 @@ public: accumulators, tiled_mma, mma_thread_idx, - shared_storage.tensors.epilogue + shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx() ); epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index dc92e931..a48f218c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -107,7 +107,7 @@ public: static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; /// Register requirement for Load and Math WGs @@ -232,6 +232,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } @@ -324,7 +326,7 @@ public: // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - + int thread_idx = int(threadIdx.x); int lane_idx = canonical_lane_idx(); int warp_idx = canonical_warp_idx_sync(); @@ -353,7 +355,7 @@ public: mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; @@ -424,14 +426,14 @@ public: CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); - static_assert(cute::tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(tiled_tensors); - Tensor gB_nkl = get<1>(tiled_tensors); + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); // Get pipeline stage increments from tensor shapes auto k_tile_count = size<3>(gA_mkl); @@ -472,7 +474,7 @@ public: params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - tiled_tensors, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index 4c8901bc..c43b50bc 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -187,6 +187,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index d1c6dc84..c867b9d0 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -207,6 +207,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 5a6571d8..403b24d1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -219,6 +219,8 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index ff64c14a..c9551ec1 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -76,6 +76,12 @@ public: is_final_split(uint32_t k_tiles_per_output_tile) const { return true; } + + CUTLASS_HOST_DEVICE + int32_t + reduction_subtile_idx() const { + return -1; + } }; using Params = PersistentTileSchedulerSm90Params; @@ -101,7 +107,8 @@ public: ClusterShape cluster_shape, [[maybe_unused]] KernelHardwareInfo const& hw_info, Arguments const& arguments, - [[maybe_unused]] void* workspace=nullptr) { + [[maybe_unused]] void* workspace=nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1) { // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic static_assert(cute::is_static::value); @@ -114,13 +121,19 @@ public: problem_blocks, to_gemm_coord(cluster_shape), hw_info, - arguments.max_swizzle_size, + arguments.max_swizzle_size, arguments.raster_order ); return params; } + CUTLASS_HOST_DEVICE + static bool + can_implement(Arguments const& args) { + return true; + } + CUTLASS_HOST_DEVICE PersistentTileSchedulerSm90() { }; @@ -164,7 +177,7 @@ public: scheduler_params.divmod_cluster_shape_major_, scheduler_params.divmod_cluster_shape_minor_, scheduler_params.divmod_cluster_blk_major_, - scheduler_params.log_swizzle_size_, + scheduler_params.log_swizzle_size_, scheduler_params.raster_order_); return {work_idx_m, work_idx_n, static_cast(work_idx_l), true}; @@ -180,11 +193,11 @@ public: static CUTLASS_DEVICE cute::tuple get_work_idx_m_and_n( - uint64_t blk_per_grid_dim, + uint64_t blk_per_grid_dim, FastDivmodU64Pow2 const& divmod_cluster_shape_major, FastDivmodU64Pow2 const& divmod_cluster_shape_minor, FastDivmodU64 const& divmod_cluster_blk_major, - int32_t log_swizzle_size, + int32_t log_swizzle_size, RasterOrder raster_order) { uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; @@ -199,26 +212,26 @@ public: } uint64_t cluster_idx_minor, cluster_idx_major; - + uint64_t cluster_idx_minor_div_swizzle, extra, offset; offset = cluster_id & ((1 << log_swizzle_size) - 1); extra = cluster_id >> log_swizzle_size; - + divmod_cluster_blk_major(cluster_idx_minor_div_swizzle, cluster_idx_major, extra); cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; - auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + + auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + cluster_minor_offset); - auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + + auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + cluster_major_offset); if (raster_order == RasterOrder::AlongN) { return {minor_work_idx, major_work_idx}; } else { - return {major_work_idx, minor_work_idx}; + return {major_work_idx, minor_work_idx}; } } @@ -331,13 +344,14 @@ public: // The basic tile scheduler does not require any additional workspace template static int - get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t) { + get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1) { return 0; } template static cutlass::Status - initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, uint32_t) { + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, + uint32_t, const uint32_t = 1) { return Status::kSuccess; } @@ -353,8 +367,61 @@ public: CUTLASS_HOST_DEVICE static uint32_t get_work_k_tile_start(WorkTileInfo const&) { - // All work units returned by this scheduler start from K tile 0 - return 0u; + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + CUTLASS_DEVICE + static bool + need_separate_reduction(Params const& params) { + return false; + } + + CUTLASS_DEVICE + bool + is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { + return false; + } + + CUTLASS_DEVICE + uint32_t + epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { + return 0; + } + + template + CUTLASS_DEVICE + void + separate_reduction( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + // Shares the accumulator set with peers in the global workspace + template + CUTLASS_DEVICE + static void + share( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; } }; diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp new file mode 100644 index 00000000..ba201fa9 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -0,0 +1,431 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/gemm_coord.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" + +namespace cutlass::gemm::kernel::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Persistent Thread Block (TB) scheduler +template +class PersistentTileSchedulerSm90Group { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_ = 0; + uint64_t total_grid_size_ = 0; + + // Tracking current group, its starting linear idx and total tiles + struct GroupInfo { + uint64_t group = 0; + uint64_t start_linear_idx = 0; + uint64_t total_tiles = 0; + } current_group_info_; + +public: + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, false}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return true; + } + + CUTLASS_HOST_DEVICE + int32_t + reduction_subtile_idx() const { + return -1; + } + }; + + using ProblemShape = typename GroupProblemShape::UnderlyingProblemShape; + using Params = PersistentTileSchedulerSm90GroupParams; + using RasterOrder = typename Params::RasterOrder; + using RasterOrderOptions = typename Params::RasterOrderOptions; + struct Arguments { + int max_swizzle_size = 1; + // Not applying Heuristics for Grouped problems, since largest dimension can change per group + RasterOrderOptions raster_order = RasterOrderOptions::AlongM; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + template + static Params + to_underlying_arguments( + GroupProblemShape problem_shapes, + TileShape tile_shape, + ClusterShape cluster_shape, + [[maybe_unused]] KernelHardwareInfo const& hw_info, + Arguments const& arguments, + [[maybe_unused]] void* workspace=nullptr, + [[maybe_unused]] const uint32_t epilogue_subtile = 1) { + + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + problem_shapes.groups(), + reinterpret_cast(problem_shapes.host_problem_shapes), + tile_shape, cluster_shape); + + Params params; + params.initialize( + problem_blocks, + problem_shapes.groups(), + reinterpret_cast(problem_shapes.problem_shapes), + to_gemm_coord(tile_shape), + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order + ); + + return params; + } + + CUTLASS_HOST_DEVICE + static bool + can_implement(Arguments const& args) { + return true; + } + + PersistentTileSchedulerSm90Group() = default; + + CUTLASS_DEVICE explicit PersistentTileSchedulerSm90Group(Params const& params_) : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + if (params_.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = uint64_t(blockIdx.x) + uint64_t(blockIdx.y) * uint64_t(gridDim.x); + } + else { + current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); + } + + total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); + + auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(params_.problem_shapes_[0]), params_.cta_shape_.m())); + auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(params_.problem_shapes_[0]), params_.cta_shape_.n())); + current_group_info_.total_tiles = cta_m * cta_n; +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx) { + if (linear_idx >= scheduler_params.blocks_per_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(linear_idx); + + auto [work_idx_m, work_idx_n, new_group_info, valid_tile] = get_work_idx_m_and_n(blk_per_grid_dim, + current_group_info_, + scheduler_params.groups_, + scheduler_params.problem_shapes_, + scheduler_params.cta_shape_, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + + current_group_info_ = new_group_info; + return {work_idx_m, work_idx_n, static_cast(current_group_info_.group), valid_tile}; + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); + } + + // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle + static CUTLASS_DEVICE + cute::tuple + get_work_idx_m_and_n( + uint64_t blk_per_grid_dim, + struct GroupInfo group_info, + int32_t total_problem_groups, + ProblemShape* problem_shapes, + GemmCoord cta_shape, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + int32_t log_swizzle_size, + RasterOrder raster_order) { + + bool valid_tile = true; + int cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m())); + int cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n())); + + while (group_info.start_linear_idx + group_info.total_tiles <= blk_per_grid_dim) { + group_info.group++; + group_info.start_linear_idx += group_info.total_tiles; + cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group_info.group]), cta_shape.m())); + cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group_info.group]), cta_shape.n())); + group_info.total_tiles = cta_m * cta_n; + } + + uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; + divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim - group_info.start_linear_idx); + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = cta_m_in_cluster; + } + else { + cluster_minor_offset = cta_n_in_cluster; + } + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + uint64_t curr_group_cluster_blk_major, remainder; + divmod_cluster_shape_major(curr_group_cluster_blk_major, remainder, cta_m); + cluster_idx_minor_div_swizzle = extra / curr_group_cluster_blk_major; + cluster_idx_major = extra % curr_group_cluster_blk_major; + + cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + + cluster_minor_offset); + auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx, group_info, valid_tile}; + } + else { + return {major_work_idx, minor_work_idx, group_info, valid_tile}; + } + + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(int groups, ProblemShape const* problem_shapes, BlockShape cta_shape, ClusterShape cluster_shape) { + uint32_t total_ctas = 0; + uint32_t cta_in_N_dim = 1; // We linearize the blocks across all the problems here + for (int group = 0; group < groups; group++) { + auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shapes[group]), cute::shape<0>(cta_shape))); + auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shapes[group]), cute::shape<1>(cta_shape))); + total_ctas += cta_m * cta_n; + } + + return Params::get_tiled_cta_shape_mnl( + to_gemm_coord(cluster_shape), + total_ctas, cta_in_N_dim + ); + } + + // Given the inputs, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + GroupProblemShape problem_shapes, + BlockShape cta_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size=true) { + + dim3 problem_blocks = get_tiled_cta_shape_mnl( + problem_shapes.groups(), + reinterpret_cast(problem_shapes.host_problem_shapes), + cta_shape, cluster_shape); + + return Params::get_grid_shape( + problem_blocks, + to_gemm_coord(cluster_shape), + hw_info, + arguments.max_swizzle_size, + arguments.raster_order, + /* truncate_by_problem_size = */true + ); + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&, Params const&) { + return true; + } + + // Performs the reduction across splits for a given output tile. Since this scheduler does + // not split output tiles, no reduction is needed. + template + CUTLASS_DEVICE + static void + fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed. + CUTLASS_DEVICE + static bool + continue_current_work(WorkTileInfo&) { + return false; + } + + // The basic tile scheduler does not require any additional workspace + template + static int + get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t, const uint32_t = 1) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, + uint32_t, const uint32_t = 1) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape_MNKL problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const&) { + // All work units returned by this scheduler start from K tile 0 + return 0u; + } + + CUTLASS_DEVICE + static bool + need_separate_reduction(Params const& params) { + return false; + } + + CUTLASS_DEVICE + bool + is_work_tile_for_reduction(WorkTileInfo const& work_tile_info, Params const& params) { + return false; + } + + CUTLASS_DEVICE + uint32_t + epilgoue_subtile_idx(WorkTileInfo const& work_tile_info, Params const& params) const { + return 0; + } + + template + CUTLASS_DEVICE + void + separate_reduction( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + // Shares the accumulator set with peers in the global workspace + template + CUTLASS_DEVICE + static void + share( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + } + + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return true; + } + + CUTLASS_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return false; + } +}; + +} // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index 584aa58e..81ff358a 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -69,6 +69,7 @@ public: using Params = PersistentTileSchedulerSm90StreamKParams; using ReductionMode = Params::ReductionMode; + using DecompositionMode = Params::DecompositionMode; struct WorkTileInfo { int32_t M_idx = 0; @@ -83,11 +84,43 @@ public: // Number of k tiles remaining for the work unit as a whole uint32_t k_tile_remaining = 0; + // Whether this unit of work is the final split for the given tile + bool is_separate_reduction = false; + CUTLASS_HOST_DEVICE bool is_valid() const { - // Use negative indices to denote invalid work - return M_idx >= 0; + // A work tile that computes no K tiles is invalid unless it is a separate-reduction work tile + // (which only performs reduction and epilogue) + return k_tile_count > 0 || is_separate_reduction; + } + + CUTLASS_HOST_DEVICE + bool + is_reduction_unit() const { + return is_separate_reduction; + } + + CUTLASS_HOST_DEVICE + int32_t + reduction_subtile_idx() const { + // For separate reduction units, the K_idx of the work tile is unused. + // Therefore, we override it to contain the subtile of that the reduction + // unit operates on. + return is_reduction_unit() ? K_idx : -1; + } + + CUTLASS_HOST_DEVICE + void + setup_separate_reduction(int32_t epilogue_subtile_idx) { + // Set the epilogue subtile in the K_idx, since this is otherwise unused + // by separate reduction units. + K_idx = epilogue_subtile_idx; + + is_separate_reduction = true; + k_tile_count = 0; + // Clean up remaining k tiles + k_tile_remaining = 0; } CUTLASS_HOST_DEVICE @@ -113,7 +146,10 @@ public: Arguments& operator=(Arguments const& args) { splits = args.splits; + max_swizzle_size = args.max_swizzle_size; raster_order = args.raster_order; + reduction_mode = args.reduction_mode; + decomposition_mode = args.decomposition_mode; return *this; } @@ -121,7 +157,10 @@ public: Arguments& operator=(Arguments&& args) noexcept { splits = args.splits; + max_swizzle_size = args.max_swizzle_size; raster_order = args.raster_order; + reduction_mode = args.reduction_mode; + decomposition_mode = args.decomposition_mode; return *this; } @@ -129,18 +168,20 @@ public: Arguments(int splits_) : splits(splits_) {} CUTLASS_HOST_DEVICE - Arguments(int splits_, int max_swizzle_size_, RasterOrderOptions raster_order_) : + Arguments(int splits_, int max_swizzle_size_, RasterOrderOptions raster_order_, DecompositionMode decomposition_mode_) : splits(splits_), max_swizzle_size(max_swizzle_size_), - raster_order(raster_order_) {} + raster_order(raster_order_), + decomposition_mode(decomposition_mode_) {} // The splitting factor to be used in a split-K decomposition of the problem. // If this is set to a value greater than 1, stream-K decomposition logic // is bypassed in favor of a split-K decomposition. int splits = 1; - const int max_swizzle_size = 1; + int max_swizzle_size = 1; RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; ReductionMode reduction_mode = ReductionMode::Deterministic; + DecompositionMode decomposition_mode = DecompositionMode::Heuristic; }; // Sink scheduler params as a member @@ -158,7 +199,8 @@ public: ClusterShape cluster_shape, KernelHardwareInfo const& hw_info, Arguments const& args, - void* workspace) { + void* workspace, + const uint32_t epilogue_subtile = 1) { static_assert(cute::is_static::value); static_assert(cute::is_static::value); @@ -177,11 +219,22 @@ public: args.max_swizzle_size, args.raster_order, args.reduction_mode, - workspace + args.decomposition_mode, + workspace, + epilogue_subtile ); return params; } + CUTLASS_HOST_DEVICE + static bool + can_implement(Arguments const& args) { + // Split count > 1 is only valid for heuristic and split-K decomposition modes + return (args.splits == 1 || + args.decomposition_mode == DecompositionMode::Heuristic || + args.decomposition_mode == DecompositionMode::SplitK); + } + CUTLASS_HOST_DEVICE PersistentTileSchedulerSm90StreamK() { }; @@ -210,7 +263,7 @@ public: // for the fact that we have splits_ peers per output tile, we multiply this // value by splits_. For stream-K, this multiplication ends up being a no-op // because splits_ is set to 1 for stream-K. - if (linear_idx >= params.units_per_problem_ * params.splits_) { + if(linear_idx >= (params.units_per_problem_ * params.splits_ + params.separate_reduction_units_)) { // Invalid work. Return an empty result. return WorkTileInfo::invalid_work_tile(); } @@ -231,8 +284,8 @@ public: current_work_linear_idx_, work_tile_info, scheduler_params); } - CUTLASS_DEVICE static - bool + CUTLASS_DEVICE + static bool continue_current_work_for_linear_idx( uint64_t linear_idx, WorkTileInfo& work_tile_info, @@ -243,9 +296,8 @@ public: if (work_tile_info.k_tile_remaining == 0) { return false; } - assign_work(params, linear_idx, work_tile_info); - return true; + return work_tile_info.is_valid(); } CUTLASS_DEVICE @@ -281,7 +333,7 @@ public: problem_blocks, to_gemm_coord(cluster_shape), hw_info, - arguments.max_swizzle_size, + arguments.max_swizzle_size, arguments.raster_order ); } @@ -290,8 +342,22 @@ public: CUTLASS_HOST_DEVICE static bool requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) { - // Fixup is not needed for data-parallel tiles - return work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor; + // Fixup is not needed for invalid or data-parallel tiles + return work_tile_info.is_valid() && work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor; + } + + CUTLASS_HOST_DEVICE + static bool + requires_separate_reduction(Params const& params) { + return params.requires_separate_reduction(); + } + + // When the work tile is not special for reduction, it's valid. Otherwise need to skip + // global loading that producer warpgroup do, also math computation that consumer warpgroup do. + CUTLASS_DEVICE + static bool + valid_warpgroup_in_work_tile(WorkTileInfo const& work_tile_info) { + return !work_tile_info.is_reduction_unit(); } // Performs the reduction across splits for a given output tile. @@ -304,7 +370,7 @@ public: FrgTensorC& accumulators, uint32_t num_barriers, uint32_t barrier_idx) { - static constexpr uint32_t Offset = 2; + static constexpr uint32_t Offset = static_cast(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); static constexpr uint32_t MaxNumNamedBarriers = 2; using BarrierManager = NamedBarrierManager; return fixup_helper( @@ -327,16 +393,27 @@ public: if (!requires_fixup(params, work_tile_info)) { return; } - auto tile_idx = output_tile_index(params, work_tile_info); // Index of the lock on which to wait auto lock_idx = (tile_idx * num_barriers) + barrier_idx; + auto reduction_tile_idx = tile_idx; + auto [first_peer_id, my_peer_id, last_peer_id] = tile_peer_range(params, tile_idx, static_cast(work_tile_info.K_idx)); + auto reduction_peer_offset = 0; + if (params.requires_separate_reduction()) { + // If separate reduction is to be performed, each stream-K unit writes its partials + // to a separate portion of the workspace. There are as many of these portions as there + // are peers for a given output tile, so we multiply the tile index by the maximum peer count. + reduction_tile_idx *= Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); + reduction_peer_offset = my_peer_id * cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}); + } + // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. // Thus, the start of the reduction space is the same across all threads in a warp group. int reduction_offset = - (cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * tile_idx) + + (cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * reduction_tile_idx) + + reduction_peer_offset + (size(accumulators) * barrier_idx * BarrierManager::ThreadCount); ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; @@ -344,56 +421,86 @@ public: using AccumulatorArrayT = Array; using BlockStripedReduceT = BlockStripedReduce; + AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); + AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); + + int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; + // The number of tiles for which reduction is required is either: // (a) the total number of output tiles (in the case of split-K) - // (b) the number of stream-K tiles + // (b) the number of stream-K tiles (potentially multiplied by peer count if using separate reduction) // To calculate the total number of output tiles in the split-K case, we // note that, in the split-K case, the units_per_problem_ member of Params will be // the total number of output tiles. - auto reduction_tiles = params.splits_ > 1 ? params.units_per_problem_ : params.sk_tiles_; + uint32_t reduction_tiles = 0; + if (params.splits_ > 1) { + reduction_tiles = params.units_per_problem_; + } + else if (params.requires_separate_reduction()) { + reduction_tiles = params.sk_tiles_ * Params::max_peers_per_tile(params.sk_units_, params.sk_tiles_); + } + else { + reduction_tiles = params.sk_tiles_; + } + auto reduction_workspace_size = Params::get_reduction_workspace_size( reduction_tiles, to_gemm_coord(TileShape{}), sizeof_bits::value); BarrierType* lock_workspace = reinterpret_cast( reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); - AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); - AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); - int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; + if (work_tile_info.is_reduction_unit()) { + plus add_fragments; + auto peer_offset = size(accumulators) * num_barriers * BarrierManager::ThreadCount; - if (!work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor)) { - if (work_tile_info.K_idx == 0) { - // First peer initializes the workspace partials + // Wait until the peers collaborating on this output tile have all written + // their accumulators to workspace. + uint32_t num_peers = last_peer_id - first_peer_id + 1; + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, num_peers); + + // Load the first peer's data + BlockStripedReduceT::load(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); + + for (int i = 1; i < num_peers; ++i) { + // Load peer fragment + AccumulatorArrayT addend_fragment; + auto peer_reduction_workspace = reinterpret_cast(group_reduction_workspace + (i * peer_offset)); + + BlockStripedReduceT::load(addend_fragment, peer_reduction_workspace, barrier_group_thread_idx); + + // Add peer fragment + *accumulator_array = add_fragments(*accumulator_array, addend_fragment); + } + } + else if (!compute_epilogue(work_tile_info, params)) { + if (params.requires_separate_reduction() || work_tile_info.K_idx == 0) { + // The first peer initializes the workspace partials in the non-separate-reduction case, + // and all peers write to their own location in workspace when using separate reduction BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); } else { - if (params.reduction_mode_ == ReductionMode::Deterministic) { - // Wait until the preceding split added its accumulators - BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); - } - else { - // Wait until the first split has stored its accumulators. Note that the first split will have - // accumulated a value into the lock potentially greater than one (since the locked value is - // incremented by work_tile_info.k_tile_count below for both the deterministic and non-deterministic) - // cases. For non-deterministic reductions, all that non-first or last splits care about is whether - // the first split has been written, so we only wait while the locked value is less than 1. This - // avoids having to add logic to determine the work_tile_info.k_tile_count for the first split. - BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); - } + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); // Perform reduction in workspace BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); } + // If separate reduction is being performed, each participating stream-K unit increments the barrier + // by only 1. Otherwise, increment by the K tile count that this unit has processed. + int32_t increment = params.requires_separate_reduction() ? 1 : work_tile_info.k_tile_count; + // Signal our arrival - BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.k_tile_count); + BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment); } else { - // Wait until the preceding split added its accumulators. - // For both the deterministic and non-deterministic case, each preceding split will have incremented - // the locked value by work_tile_info.k_tile_count. Thus, the final split konws that it can begin - // loading the partially-reduced value when the locked value reaches its starting K tile index (i.e., - // work_tile_info.K_idx). - BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + if (params.reduction_mode_ == ReductionMode::Deterministic) { + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + } + else { + // Wait unitl the first split has stored its accumulators + BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); + } // The block computing the final split for the tile adds previously-reduced partials // to its accumulators and computes the epilogue. @@ -406,7 +513,13 @@ public: CUTLASS_HOST_DEVICE static bool compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) { - return work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor); + // `is_final_split` will be set to `true` for the following scenarios, all of which must compute the epilogue: + // 1. The tile is computed in data-parallel mode + // 2. The tile is computed in split-/stream-K mode and this work unit represents the final split of the tile + // 3. The tile is computed in split-/stream-K mode and separate reduction is used, and this is a separate reduction unit + return work_tile_info.is_valid() && + (work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor) && + !params.requires_separate_reduction()) || work_tile_info.is_separate_reduction; } // Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K] @@ -432,7 +545,8 @@ public: Arguments const& args, ProblemShape problem_shape, KernelHardwareInfo const& hw_info, - uint32_t mma_warp_groups) { + uint32_t mma_warp_groups, + const uint32_t epilogue_subtile = 1) { auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); @@ -451,9 +565,11 @@ public: args.splits, args.max_swizzle_size, args.raster_order, + args.decomposition_mode, mma_warp_groups, sizeof_bits::value, - sizeof_bits::value + sizeof_bits::value, + epilogue_subtile ); } @@ -464,8 +580,9 @@ public: void* workspace, cudaStream_t stream, ProblemShape const& problem_shape, - KernelHardwareInfo const& hw_info, - uint32_t mma_warp_groups) { + KernelHardwareInfo const& hw_info, + uint32_t mma_warp_groups, + const uint32_t epilogue_subtile = 1) { auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); @@ -486,9 +603,11 @@ public: args.splits, args.max_swizzle_size, args.raster_order, + args.decomposition_mode, mma_warp_groups, sizeof_bits::value, - sizeof_bits::value + sizeof_bits::value, + epilogue_subtile ); } @@ -505,6 +624,7 @@ public: return work_tile_info.K_idx; } +private: // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining // iterations) is used to find the next tile in the current work unit. @@ -515,10 +635,22 @@ public: uint64_t linear_idx, WorkTileInfo& work_tile_info) { - uint64_t true_tile_id = linear_idx; - if (linear_idx >= params.sk_units_ && params.splits_ == 1) { + uint64_t output_tile_id = linear_idx; + if (linear_idx >= params.units_per_problem_ * params.splits_) { + // Separate-reduction work + auto cluster_size = params.get_cluster_size(); + // Divide up the linearized separate reduction units into clusters + auto cluster_linear_reduction_unit_idx = params.div_cluster_size((linear_idx - params.units_per_problem_)); + uint64_t cluster_tile_idx, epi_subtile_idx; + params.divmod_epilogue_subtile_(cluster_tile_idx, epi_subtile_idx, cluster_linear_reduction_unit_idx); + // Bring the linearized tile ID back into the space of tiles, rather than clusters + output_tile_id = cluster_tile_idx * cluster_size; + + work_tile_info.setup_separate_reduction(epi_subtile_idx); + } + else if (linear_idx >= params.sk_units_ && params.splits_ == 1) { // Data-parallel work - true_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; + output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; work_tile_info.K_idx = 0; work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; @@ -540,48 +672,114 @@ public: // To do so, we divide up the linearized stream-K units into clusters and share the same K // offsets for work within clusters. - // Equivalent to linear_idx / cluster_size - auto cluster_linear_work_idx = params.divmod_cluster_shape_minor_.divide( - params.divmod_cluster_shape_major_.divide(linear_idx) - ); + auto cluster_linear_work_idx = params.div_cluster_size(linear_idx); + + uint64_t group_idx; + params.divmod_sk_groups_(cluster_linear_work_idx, group_idx, cluster_linear_work_idx); + + // Determine whether we are in a "big group" that will process an additional + // stream-K cluster tile. + auto sk_cluster_tiles = params.div_cluster_size(params.sk_tiles_); + auto sk_cluster_tiles_in_group = params.divmod_sk_groups_.divide(sk_cluster_tiles); + if (group_idx < params.big_groups_) { + ++sk_cluster_tiles_in_group; + } + + // Determine whether we are in a "big unit" within the group, that will process + // an additional K chunk in the group. + auto sk_tiles_in_group = sk_cluster_tiles_in_group * params.get_cluster_size(); + auto k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; + auto k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); + auto big_units_in_group = params.div_cluster_size( + k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor)); uint64_t split; params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); - auto big_unit_cmp = params.splits_ > 1 ? split : cluster_linear_work_idx; - auto linear_idx_mult = params.splits_ > 1 ? params.divmod_tiles_per_output_tile_.divisor : params.k_tiles_per_sk_unit_; + + bool is_split_k = params.splits_ > 1; + auto big_unit_cmp_lhs = is_split_k ? split : cluster_linear_work_idx; + auto big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; + auto linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; + auto k_tiles_per_split = is_split_k ? params.k_tiles_per_sk_unit_ : k_tiles_per_unit_in_group; // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + (params.k_tiles_per_sk_unit_ * split); + uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + + (k_tiles_per_split * split); // Adjust the starting position and number of k iterations for "big units," which - // compute one extra iteration. These are the first big_units_ units in the - // linearized ID space. - bool is_big_unit = big_unit_cmp < params.big_units_; - if (is_big_unit) { + // compute one extra iteration. If there are any big units, they will be the first + // in the linearized ID space. + auto k_tiles_in_my_split = k_tiles_per_split; + if (big_unit_cmp_lhs < big_unit_cmp_rhs) { // Since the "big units" are the first units in the linearized ID space, each // of the units preceding this big unit computed one extra iteration. Thus, // we must offset our start iteration by the number of units that precede // the current unit in the linearized ID space. - unit_iter_start += big_unit_cmp; + unit_iter_start += big_unit_cmp_lhs; + ++k_tiles_in_my_split; } else { // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += params.big_units_; + unit_iter_start += big_unit_cmp_rhs; + } + + if (!is_split_k) { + // Adjust the unit starting position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + int unused, start_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); + if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another + // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take over these K tiles. + unit_iter_start -= start_tile_k_tile; + k_tiles_in_my_split += start_tile_k_tile; + } + else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); + unit_iter_start += adjustment_tiles; + k_tiles_in_my_split -= adjustment_tiles; + } } if (work_tile_info.k_tile_count == 0) { // This is a new unit - work_tile_info.k_tile_remaining = params.k_tiles_per_sk_unit_; - // Only adjust iteration count for big unit if we are initializing this - // work unit. For existing work units, the extra iteration for big units - // has already been accounted for in k_tiles_reamaining - if (is_big_unit) { - ++work_tile_info.k_tile_remaining; + if (!is_split_k) { + // + // Adjust the unit ending position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + // + + // Begin by assuming that no adjustment is needed + auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; + + int unused, end_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); + + if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + k_tiles_in_my_split -= end_tile_k_tile; + } + else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take on these K tiles. + k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); + } } + + work_tile_info.k_tile_remaining = k_tiles_in_my_split; } - // Find the output tile corresponding to the final k iteration covered by this + uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; + + // Find the output tile corresponding to the final k tile covered by this // work unit. Stream-K work units will work backwards in terms of the tiles they // are responsible computing. This is beneficial because the final (partial) // tile computed by a stream-K block is typically the beginning of the output @@ -590,43 +788,45 @@ public: // other work units computing portions of that output tile, it is preferable // for them to be computed later, so as to reduce the likelihood of blocking // on other work. - uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; - true_tile_id = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); - uint32_t true_tile_iter_start = true_tile_id * params.divmod_tiles_per_output_tile_.divisor; - uint32_t true_tile_iter_end = true_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); + uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + + // Convert the output tile from the linearized space within each group to the + // overall linearized space. + output_tile_id = (output_tile_id_in_group * params.divmod_sk_groups_.divisor) + group_idx; // Bring the linearized tile ID back into the space of tiles, rather than clusters - true_tile_id *= params.divmod_cluster_shape_major_.divisor * params.divmod_cluster_shape_minor_.divisor; + output_tile_id *= params.get_cluster_size(); auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); // The final linearized tile ID is in units of the cluster dimension over which we rasterize. if (params.raster_order_ == RasterOrder::AlongN) { - true_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; + output_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; } else { - true_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; + output_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; } // The unit's starting k iteration in the current tile is either the starting // iteration for the tile as a whole, or the starting k iteration for the unit // as a whole (if the latter is greater than the former). - uint32_t tile_iter_start = max(true_tile_iter_start, unit_iter_start); + uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); // Similarly, the unit's ending k iteration (exclusive) is either the end of // the current tile it is assigned, or the ending iteration of the unit as a whole // (if the latter is less than the former). - uint32_t tile_iter_end = min(true_tile_iter_end, unit_iter_end + 1); + uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); // Set the k offset to be the starting k tile for this output tile - work_tile_info.K_idx = static_cast(tile_iter_start - true_tile_iter_start); - + work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; } uint64_t work_idx_l, remainder; - params.divmod_batch_(work_idx_l, remainder, true_tile_id); + params.divmod_batch_(work_idx_l, remainder, output_tile_id); uint64_t cta_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder); @@ -642,7 +842,57 @@ public: work_tile_info.M_idx = work_idx_m; work_tile_info.N_idx = work_idx_n; work_tile_info.L_idx = static_cast(work_idx_l); + } + // Returns the starting and ending peer ID of this tile + CUTLASS_HOST_DEVICE + static auto + tile_peer_range(Params const& params, uint32_t tile_idx, uint32_t cur_k_tile) { + auto tile_idx_in_cluster_path = params.div_cluster_size(tile_idx); + auto start_k_tile = params.divmod_tiles_per_output_tile_.divisor * tile_idx_in_cluster_path; + auto end_k_tile = start_k_tile + params.divmod_tiles_per_output_tile_.divisor - 1; + auto big_unit_k_tiles = params.big_units_ * (params.k_tiles_per_sk_unit_ + 1); + + auto adjust_unit = [&](uint32_t k_tile, uint32_t unit_idx, uint32_t k_tiles_per_unit) { + auto unit_k_start = unit_idx * k_tiles_per_unit; + auto unit_k_end = unit_k_start + k_tiles_per_unit; + if (k_tile - start_k_tile < Params::min_iters_per_sk_unit_ && + unit_k_end - start_k_tile < Params::min_iters_per_sk_unit_) { + // k_tile is within the first min_iters_per_sk_unit_ K tiles of this output tile, + // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this + // output tile. This work will thus be subsumed by the next stream-K unit. + ++unit_idx; + } + + if (end_k_tile + 1 - k_tile < Params::min_iters_per_sk_unit_ && + end_k_tile + 1 - unit_k_start < Params::min_iters_per_sk_unit_) { + // k_tile is within the last min_iters_per_sk_unit_ K tiles of this output tile, + // and the stream-K unit computes fewer than min_iters_per_sk_unit_ K tiles for this + // output tile. This work will thus be subsumed by the previous stream-K unit. + --unit_idx; + } + + return unit_idx; + }; + + // Lambda to find the ID of the stream-K unit that computes this K tile + auto find_unit = [&](uint32_t k_tile) { + if (k_tile < big_unit_k_tiles) { + // The tile is within the "big unit range" + auto k_tiles_per_unit = params.k_tiles_per_sk_unit_ + 1; + auto unit_idx = k_tile / k_tiles_per_unit; + return static_cast(adjust_unit(k_tile, unit_idx, k_tiles_per_unit)); + } + else { + // The tile is after the "big unit range." Account for this by finding the "normal unit" + // that it belongs to, and then offsetting by the number of big units + auto k_tiles_per_unit = params.k_tiles_per_sk_unit_; + auto unit_idx = ((k_tile - big_unit_k_tiles) / params.k_tiles_per_sk_unit_) + (params.big_units_); + return static_cast(adjust_unit(k_tile, unit_idx, k_tiles_per_unit)); + } + }; + + return cute::make_tuple(find_unit(start_k_tile), find_unit(cur_k_tile), find_unit(end_k_tile)); } }; diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index a81460e4..a46723a1 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -38,6 +38,7 @@ #include "cutlass/detail/dependent_false.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm { @@ -52,6 +53,8 @@ struct PersistentScheduler { }; struct StreamKScheduler { }; +struct GroupScheduler { }; // Only used for Grouped GEMMs + //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm @@ -69,6 +72,7 @@ template < class ArchTag, class TileShape, class ClusterShape + , class ProblemShapeType = void > struct TileSchedulerSelector { static_assert(cutlass::detail::dependent_false, @@ -122,6 +126,21 @@ struct TileSchedulerSelector< using Scheduler = PersistentTileSchedulerSm90StreamK; }; +template < + class TileShape, + class ClusterShape + , class GroupProblemShape +> +struct TileSchedulerSelector< + GroupScheduler, + arch::Sm90, + TileShape, + ClusterShape + , GroupProblemShape + > { + using Scheduler = PersistentTileSchedulerSm90Group; +}; + //////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 8cfb4845..be1251ca 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -218,7 +218,7 @@ struct PersistentTileSchedulerSm90Params { auto possibly_truncate = [&](int x, int y) { if (truncate_by_problem_size) { - return cutlass::platform::min(x, y); + return platform::min(x, y); } else { return x; @@ -272,7 +272,7 @@ struct PersistentTileSchedulerSm90Params { CUTLASS_HOST_DEVICE static int32_t get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { - int min_cta_dim = cutlass::platform::min(problem_ctas_m, problem_ctas_n); + int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); if (max_swizzle_size >= 8 && min_cta_dim >= 6) { return 3; } @@ -370,6 +370,18 @@ struct PersistentTileSchedulerSm90StreamKParams { Nondeterministic }; + // Strategies for decomposing the problem + enum class DecompositionMode { + // Use a heuristic to determine whether data-parallel, split-K, or stream-K decomposition should be performed + Heuristic, + // Force a data-parallel decomposition + DataParallel, + // Force a split-K decomposition. This should be paired with setting the `splits` parameter + SplitK, + // Force a stream-K decomposition + StreamK + }; + using UnderlyingParams = PersistentTileSchedulerSm90Params; using RasterOrder = UnderlyingParams::RasterOrder; using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; @@ -387,6 +399,17 @@ struct PersistentTileSchedulerSm90StreamKParams { // and may be overridden in other decompositions. FastDivmodU64 divmod_clusters_mnl_{}; + // We divide up the number of stream-K tiles amongst G groups of stream-K units. + // The stream-K units within a group collaborate to comptue over the `sk_tiles / G` + // tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of + // partial chunks computed by stream-K units -- units 0 in each group will compute identical K extents + // of tiles that would be assigned in the same wave according to the rasterization order of the + // data-parallel formulation of the problem. + FastDivmodU64 divmod_sk_groups_{}; + + // Number of stream-K units in each group + FastDivmodU64 divmod_sk_units_per_group_{}; + uint64_t units_per_problem_ = 0; FastDivmod divmod_tiles_per_output_tile_{}; int32_t log_swizzle_size_ = 0; @@ -403,6 +426,9 @@ struct PersistentTileSchedulerSm90StreamKParams { // at the granularity of a cluster, we store only the number of big clusters. uint32_t big_units_ = 0; + // The number of groups of stream-K units that will process an extra stream-K tile cluster. + uint32_t big_groups_ = 0; + // Workspace for holding partial accumulators to be reduced across stream-K/split-K units void* reduction_workspace_ = nullptr; @@ -419,8 +445,53 @@ struct PersistentTileSchedulerSm90StreamKParams { // Strategy to use when reducing between collaborating CTAs ReductionMode reduction_mode_ = ReductionMode::Deterministic; - // Minimum number of tiled k that can be assigned to a stream-K unit - static constexpr uint32_t min_iters_per_sk_unit_ = 4u; + // The number of sub blocks in the kernel epilogue + FastDivmodU64 divmod_epilogue_subtile_{}; + + // The number of blocks that launched for doing separate reduction + uint32_t separate_reduction_units_ = 0; + + // Minimum number of k tiles that can be assigned to a stream-K unit + static constexpr uint32_t min_iters_per_sk_unit_ = 8u; + + // Maximum number of groups of stream-K units + static constexpr uint32_t max_sk_groups_ = 8u; + + // Divides dividend by the cluster size + CUTLASS_HOST_DEVICE + uint64_t + div_cluster_size(uint64_t dividend) const { + // Use each underlying fast divmod rather than performing integer division + // by the multiplication of major.divisor * minor.divisor + return divmod_cluster_shape_minor_.divide( + divmod_cluster_shape_major_.divide(dividend) + ); + } + + CUTLASS_HOST_DEVICE + uint64_t + get_cluster_size() const { + return divmod_cluster_shape_minor_.divisor * divmod_cluster_shape_major_.divisor; + } + + // Returns whether the kernel uses separate reduction + CUTLASS_HOST_DEVICE + bool + requires_separate_reduction() const { + return separate_reduction_units_ > 0; + } + + // Returns the maximum number of peers that can collaborate on a given output tile + CUTLASS_HOST_DEVICE + static uint32_t + max_peers_per_tile(uint64_t sk_units, uint64_t sk_tiles) { + // When we can divide up our SK units to SK tiles evenly, the number of peers + // per SK tile is exactly (sk_units_ / sk_tiles_). In cases where this division + // is not exact, some tiles will need to be covered by additional SK units. Because + // the extra work can occur at both the beginning and the end of the SK tile, at + // most 2 extra peers will be needed. + return static_cast(sk_units / sk_tiles + 2); + } // Initializes members. This variant of the method should only be used when // problem_shape and tile_shape contain modes of only rank 1. @@ -434,7 +505,9 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, ReductionMode reduction_mode, - void* workspace + DecompositionMode decomposition_mode, + void* workspace, + const uint32_t epilogue_subtile = 1 ) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl( problem_shape, tile_shape, cluster_shape); @@ -451,7 +524,9 @@ struct PersistentTileSchedulerSm90StreamKParams { max_swizzle, raster_order_option, reduction_mode, - workspace + decomposition_mode, + workspace, + epilogue_subtile ); } @@ -468,7 +543,9 @@ struct PersistentTileSchedulerSm90StreamKParams { int max_swizzle, RasterOrderOptions raster_order_option, ReductionMode reduction_mode, - void* workspace + DecompositionMode decomposition_mode, + void* workspace, + const uint32_t epilogue_subtile = 1 ) { UnderlyingParams underlying_params; underlying_params.initialize( @@ -488,7 +565,8 @@ struct PersistentTileSchedulerSm90StreamKParams { // Reduction workspace is at the beginning of the workspace. Lock workspace follows. void* reduction_workspace = workspace; - if (splits > 1) { + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { // Short circuit to basic split-K decomposition // Don't split by more than the available number of SMs @@ -531,24 +609,9 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t ctas_per_wave = grid.x * grid.y; // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. - uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile); + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, k_tiles_per_output_tile, decomposition_mode); uint64_t dp_tiles = output_tiles - sk_tiles; - if (sk_tiles == 0) { - // Short circuit to basic data-parallel decomposition - set_params_basic( - underlying_params, - problem_blocks_m, - problem_blocks_n, - problem_blocks_l, - /* splits = */ 1, - k_tiles_per_output_tile, - reduction_workspace, - reduction_mode - ); - return; - } - // Calculate the number of work units covering the data-parallel and stream-K tiles. // A "work unit" is a single index in the linearized ID space used by the scheduler. // We distinguish it from a "block," which is typically tied to a hardware unit @@ -576,12 +639,127 @@ struct PersistentTileSchedulerSm90StreamKParams { uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_); min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size; - uint64_t sk_units = cutlass::platform::min(ctas_per_wave, min_sized_sk_units); + uint64_t sk_units = platform::min(ctas_per_wave, min_sized_sk_units); - // If the number of stream-K units is a multiple of the number of stream-K tiles, then - // the problem can leverage a basic split-K decomposition for the stream-K tiles. - if (sk_tiles < sk_units && sk_units % sk_tiles == 0) { - // Short circuit to basic split-K decomposition + if (decomposition_mode == DecompositionMode::DataParallel || + (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || + sk_units == 0) { + // Short circuit to basic data-parallel decomposition + set_params_basic( + underlying_params, + problem_blocks_m, + problem_blocks_n, + problem_blocks_l, + /* splits = */ 1, + k_tiles_per_output_tile, + reduction_workspace, + reduction_mode + ); + return; + } + + bool do_separate_reduction = should_perform_separate_reduction( + epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave); + + // Determine the number of stream-K groups that will be used. We currently use + // max_sk_groups_ unless this extends beyond the extent of the dimension over + // which the problem is rasterized. For example, if the tiled problem shape + // (in CTA_M x CTA_N representation) when using 1x1 clusters is 4x16, + // and we rasterize along the M dimension, we choose 4 groups, rather than 8. + // If the cluster shape is 2x1, we choose 2 groups (CTA_M / CLUSTER_M). + uint32_t max_groups_problem; + if (underlying_params.raster_order_ == RasterOrder::AlongM) { + max_groups_problem = problem_blocks_m / cluster_shape.m(); + } + else { + max_groups_problem = problem_blocks_n / cluster_shape.n(); + } + + // Select the number of groups that will be use. We start with the maximum + // number of potential groups, and iterate down looking for a group size that + // evenly divides the stream-K units and tiles, and for which the resulting + // number of K tiles per stream-K unit remains above min_iters_per_sk_unit_ + + uint32_t groups = platform::min(max_groups_problem, uint32_t(max_sk_groups_)); + + // Grouping is disabled when separate reduction is used + if (do_separate_reduction) { + groups = 1; + } + + uint32_t fallback_groups = 0; + auto sk_cluster_tiles = sk_tiles / cluster_size; + auto sk_cluster_units = sk_units / cluster_size; + + auto sk_splits_too_small = [&](uint32_t g) { + // Check whether the number of K tiles computed per stream-K unit is less + // than min_iters_per_sk_unit_ + auto total_sk_k_tiles = (sk_tiles / g) * k_tiles_per_output_tile; + auto k_tiles_per_sk_unit = total_sk_k_tiles / (sk_units / g); + return k_tiles_per_sk_unit < min_iters_per_sk_unit_; + }; + + auto is_ideal_grouping = [&](uint32_t g) { + // An ideal grouping will evenly divide stream-K clusters, evenly divide + // stream-K tiles, and not result in stream-K splits that are too small. + return (sk_cluster_units % g == 0) && (sk_cluster_tiles % g == 0) && !sk_splits_too_small(g); + }; + + auto is_valid_grouping = [&](uint32_t g) { + // A grouping is valid, but not ideal, if it evenly divides the + // stream-K clusters and does not result in stream-K splits that are + // too small. Such a setting can be used as a fallback option in the + // case that an ideal grouping is not achievable + return sk_cluster_units % g == 0 && !sk_splits_too_small(g); + }; + + while (groups > 1 && !is_ideal_grouping(groups)) { + if (fallback_groups == 0 && is_valid_grouping(groups)) { + // Set fallback groups once in preference for a larger number of groups. + fallback_groups = groups; + } + --groups; + } + + // If groups == 1, we did not find a group count that satisfies all criteria. If we have + // found a fallback group count, use this instead. + if (groups == 1 && fallback_groups > 0) { + groups = fallback_groups; + } + + auto sk_units_per_group = sk_units / groups; + + // sk_tiles is guaranteed to be divisible by cluster_size because it is calculated as: + // sk_tiles = (waves <= 2) ? total_tiles : (sm_count + (total_tiles % sm_count)) + // Both total_tiles and sm_count are multiples of cluster size due to padding added + // prior to kernel launch. + uint64_t sk_clustered_tiles = sk_tiles / cluster_size; + uint64_t sk_clustered_tiles_per_group = sk_clustered_tiles / groups; + uint64_t sk_tiles_per_group = sk_clustered_tiles_per_group * cluster_size; + + // Groups that will process an extra stream-K tile cluster. These differ from "big_units," which + // are stream-K units within a group that process an extra K chunk. + uint64_t sk_big_groups = sk_clustered_tiles % groups; + + uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; + + // Number of k tiles computed per stream-K unit + uint64_t k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; + + uint32_t reduction_units = 0; + + // Use separate reduction when we have less than one wave of output tiles (dp_tiles == 0) + // and when each tile will be operated on by at least two stream-K units (sk_units > 2 * sk_tiles) + if (do_separate_reduction) { + // Each reduction unit will reduce the partials of an epilogue subtile for + // a given output tile and compute the epilogue. Thus, there are as many reduction + // units as there are epilogue subtiles. + reduction_units = sk_tiles * epilogue_subtile; + } + else if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { + // If the number of stream-K units is a multiple of the number of stream-K tiles, then + // the problem can leverage a basic split-K decomposition for the stream-K tiles. + // This case happens when separate reduction is disable. uint32_t sk_splits = static_cast(sk_units / sk_tiles); set_params_basic( underlying_params, @@ -595,37 +773,13 @@ struct PersistentTileSchedulerSm90StreamKParams { ); return; } - - // Number of k iterations computed per stream-K units - uint64_t k_tiles_per_sk_unit = k_tiles_sk_total / sk_units; - - // Number of stream-K units that need to compute extra iterations in order to cover - // the residual k iterations. This assumes that each such unit computes one additional - // iteration. - uint64_t sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units); - - // The division below is guaranteed to be exact because sk_big_units is guaranteed - // to be a multiple of cluster_size. This is useful because - // it allows us to use a block's linearized cluster ID to determine whether it is - // a big block. The reasoning behind this guarnatee is explained as follows: - // sk_big_units = k_tiles_sk_total - (k_tiles_per_sk_unit * sk_units); - // - // - k_tiles_sk_total is a multiple of cluster_size because it is the product - // of number of tail tiles and the number of k iterations per tile. Because - // both the number of output tiles and number of available SMs are rounded - // to be multiples of cluster shape, the number of tail tiles - // (output_tiles % avail_sms) is a multpile of cluster_size. - // - // - sk_units is a multiple of cluster_size because it is either blocks_per_wave - // or 0, and blocks_per_wave is a multiple of the cluster_size due to the grid-planning - // logic rounding to multiples of cluster dimensions - uint64_t sk_big_units_per_cluster = sk_big_units / cluster_size; - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; divmod_batch_ = underlying_params.divmod_batch_; divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; + divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); + divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. // This setting ensures that the use of this divmod for stream-K decompositions @@ -635,12 +789,19 @@ struct PersistentTileSchedulerSm90StreamKParams { log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = static_cast(dp_units + sk_units); raster_order_ = underlying_params.raster_order_; - big_units_ = static_cast(sk_big_units_per_cluster); + + // Assign big_units_ assuming that group count == 1. This is unused by stream-K + // when group count > 1. + big_units_ = static_cast(k_tiles_per_group % k_tiles_per_sk_unit); + + big_groups_ = static_cast(sk_big_groups); reduction_workspace_ = reduction_workspace; sk_tiles_ = sk_tiles; sk_units_ = static_cast(sk_units); k_tiles_per_sk_unit_ = static_cast(k_tiles_per_sk_unit); reduction_mode_ = reduction_mode; + divmod_epilogue_subtile_ = FastDivmodU64(epilogue_subtile); + separate_reduction_units_ = reduction_units; } // Given the inputs, computes the physical grid we should launch. @@ -696,23 +857,28 @@ struct PersistentTileSchedulerSm90StreamKParams { // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total // output tiles on a device with `ctas_per_wave` CTAs in each wave. static uint32_t - get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile) { + get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave, uint32_t k_tiles_per_output_tile, DecompositionMode decomposition_mode) { uint32_t full_waves = static_cast(output_tiles / ctas_per_wave); uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); - if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { - // All tiles will be data-parallel tiles if there is either no quantization - // or if there is no work to be split. + if (decomposition_mode == DecompositionMode::DataParallel || + decomposition_mode == DecompositionMode::SplitK) { return 0; } - // - // The final wave is not full. Perform some stream-K work. - // + if (decomposition_mode == DecompositionMode::Heuristic) { + if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { + // All tiles will be data-parallel tiles if there is either no quantization + // or if there is no work to be split. + return 0; + } - // Rudimentary heuristic: prefer data-parallel decomposition if we have more than - // one wave and the tail wave is more than half full. This is subject to change. - if (full_waves != 0) { + // + // The final wave is not full. Perform some stream-K work. + // + + // Rudimentary heuristic: prefer data-parallel decomposition if we have more than + // one wave and the tail wave is more than half full. This is subject to change. uint64_t tail_tiles = output_tiles - (full_waves * ctas_per_wave); if (tail_tiles >= (ctas_per_wave / 2)) { return 0; @@ -729,6 +895,22 @@ struct PersistentTileSchedulerSm90StreamKParams { return static_cast(output_tiles - dp_tiles); } + CUTLASS_HOST_DEVICE + static uint64_t + get_num_sk_units(GemmCoord cluster_shape, uint64_t ctas_per_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) { + // Number of k iterations computed by the stream-K units as a whole + uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles; + + // Calculate the number of stream-K units that would be needed if each stream-K unit + // computed the minimum allowable k iterations. Truncate this to be in units of clusters. + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_); + min_sized_sk_units = (min_sized_sk_units / cluster_size) * cluster_size; + + uint64_t sk_units = platform::min(ctas_per_wave, min_sized_sk_units); + return sk_units; + } + // Calculates the size of the workspace needed for holding reduction barriers CUTLASS_HOST_DEVICE static int @@ -759,9 +941,11 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int max_swizzle, RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, - uint32_t accumulator_bits) { + uint32_t accumulator_bits, + uint32_t epilogue_subtile = 1) { auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle); problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); @@ -771,7 +955,12 @@ struct PersistentTileSchedulerSm90StreamKParams { // of output tiles that will be split, and then calculate the workspace needed to cover these. uint64_t output_tiles = problem_blocks.x * problem_blocks.y * problem_blocks.z; - if (splits > 1) { + if (decomposition_mode == DecompositionMode::DataParallel) { + barrier_workspace_size = 0; + reduction_workspace_size = 0; + } + else if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { // Basic split-K variant requires workspace for all output tiles barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups, barrier_bits); reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits); @@ -794,14 +983,39 @@ struct PersistentTileSchedulerSm90StreamKParams { raster_order_option ); uint64_t ctas_per_wave = grid.x * grid.y; - uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, static_cast(k_tiles_per_output_tile)); + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave, static_cast(k_tiles_per_output_tile), decomposition_mode); + uint64_t sk_units = get_num_sk_units(cluster_shape, ctas_per_wave, sk_tiles, k_tiles_per_output_tile); + uint64_t dp_tiles = output_tiles - sk_tiles; + uint64_t reduction_tiles = sk_tiles; + if (should_perform_separate_reduction(epilogue_subtile, sk_units, sk_tiles, dp_tiles, ctas_per_wave)) { + // In separate reduction, each peer writes to its own location in scratch space. + // Thus, for separate reduction, we need as many reduction tiles per output tile + // as there are the maximum number of peers that can collaborate on an output tile. + reduction_tiles *= max_peers_per_tile(sk_units, sk_tiles); + } + + // Though separate reduction requires a larger reduction workspace, only one barrier + // is needed per output tile. Each peer will increment the barrier by one once the peer has + // written its accumulator to scratch space. The separate reduction unit will only begin + // performing the reduction when the barrier has reached the number of peers for the output tile. barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups, barrier_bits); - reduction_workspace_size = get_reduction_workspace_size(sk_tiles, tile_shape, accumulator_bits); + reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits); } } #endif // !defined(__CUDACC_RTC__) + // Returns whether the kernel is configured in a manner for which separate reduction should be used + CUTLASS_HOST_DEVICE + static bool + should_perform_separate_reduction(uint32_t epilogue_subtile, uint64_t sk_units, uint64_t sk_tiles, uint64_t dp_tiles, uint64_t ctas_per_wave) { + // We perform separate reduction if we have fewer than one wave of output tiles + // and each output tile is covered by at least to stream-K units. When sk_units is + // multiple of sk_tiles, will choose basic split-k path instead of separate reduction for now. + return (epilogue_subtile != 1) && (dp_tiles == 0) && (sk_units > 2u * sk_tiles) && + (sk_units + sk_tiles * epilogue_subtile <= ctas_per_wave); + } + // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when // problem_shape and tile_shape contain modes of only rank 1. static int @@ -813,9 +1027,11 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int max_swizzle, RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, - uint32_t element_accumulator_bits) { + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -829,9 +1045,11 @@ struct PersistentTileSchedulerSm90StreamKParams { splits, max_swizzle, raster_order_option, + decomposition_mode, mma_warp_groups, barrier_bits, - element_accumulator_bits + element_accumulator_bits, + epilogue_subtile ); } @@ -848,9 +1066,11 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int max_swizzle, RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, - uint32_t element_accumulator_bits) { + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile = 1) { int barrier_workspace_size = 0; int reduction_workspace_size = 0; @@ -867,9 +1087,11 @@ struct PersistentTileSchedulerSm90StreamKParams { splits, max_swizzle, raster_order_option, + decomposition_mode, mma_warp_groups, barrier_bits, - element_accumulator_bits + element_accumulator_bits, + epilogue_subtile ); #endif @@ -889,9 +1111,11 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int max_swizzle, RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, - uint32_t element_accumulator_bits) { + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile) { dim3 problem_blocks = UnderlyingParams::get_tiled_cta_shape_mnl(problem_shape, tile_shape, cluster_shape); uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); @@ -907,9 +1131,11 @@ struct PersistentTileSchedulerSm90StreamKParams { splits, max_swizzle, raster_order_option, + decomposition_mode, mma_warp_groups, barrier_bits, - element_accumulator_bits + element_accumulator_bits, + epilogue_subtile ); } @@ -928,9 +1154,11 @@ struct PersistentTileSchedulerSm90StreamKParams { int splits, int max_swizzle, RasterOrderOptions raster_order_option, + DecompositionMode decomposition_mode, uint32_t mma_warp_groups, uint32_t barrier_bits, - uint32_t element_accumulator_bits) { + uint32_t element_accumulator_bits, + uint32_t epilogue_subtile = 1) { #if !defined(__CUDACC_RTC__) int barrier_workspace_size = 0; @@ -947,9 +1175,11 @@ struct PersistentTileSchedulerSm90StreamKParams { splits, max_swizzle, raster_order_option, + decomposition_mode, mma_warp_groups, barrier_bits, - element_accumulator_bits + element_accumulator_bits, + epilogue_subtile ); if (barrier_workspace_size > 0) { @@ -982,6 +1212,7 @@ struct PersistentTileSchedulerSm90StreamKParams { divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + divmod_sk_groups_ = FastDivmodU64(1u); auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor; divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); splits_ = splits; @@ -997,9 +1228,11 @@ struct PersistentTileSchedulerSm90StreamKParams { // No stream-K work is performed for "basic" data-parallel and split-K decompositions sk_tiles_ = 0; sk_units_ = 0; + divmod_sk_units_per_group_ = FastDivmodU64(1u); + separate_reduction_units_ = 0; } -private: + private: // Round up number of bytes to the nearest multiple of L2 cache line alignment CUTLASS_HOST_DEVICE static int @@ -1009,6 +1242,236 @@ private: } }; +//////////////////////////////////////////////////////////////////////////////// + +// Parameters for SM90 persistent group scheduler (only used for Grouped Gemms) +template +struct PersistentTileSchedulerSm90GroupParams { + + enum class RasterOrder { + AlongM, + AlongN + }; + + enum class RasterOrderOptions { + Heuristic, + AlongM, + AlongN + }; + + FastDivmodU64Pow2 divmod_cluster_shape_major_{}; + FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; + FastDivmodU64 divmod_batch_{}; + + uint64_t blocks_per_problem_ = 0; + int32_t log_swizzle_size_ = 0; + RasterOrder raster_order_ = RasterOrder::AlongN; + + int32_t groups_ = 0; + ProblemShape* problem_shapes_ = nullptr; + GemmCoord cta_shape_; + + // Version of initialize that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + void + initialize( + dim3 problem_blocks, + int32_t groups, + ProblemShape* problem_shapes, + GemmCoord cta_shape, + GemmCoord cluster_shape, + KernelHardwareInfo const& hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option + ) { + + CUTLASS_UNUSED(hw_info); + + // Round up to nearest multiple of swizzle_size along each mode + auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); + auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + + RasterOrder raster_order = get_rasterization_order( + problem_blocks_m, + problem_blocks_n, + raster_order_option + ); + + // + // Set members + // + groups_ = groups; + problem_shapes_ = problem_shapes; + cta_shape_ = cta_shape; + + blocks_per_problem_ = problem_blocks_m * problem_blocks_n * problem_blocks.z; + log_swizzle_size_ = log_swizzle_size; + raster_order_ = raster_order; + divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); + + if (raster_order == RasterOrder::AlongN) { + divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n()); + divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m()); + } + else { + divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m()); + divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n()); + } + } + + // Version of get_tiled_cta_shape_mnl that takes in as input the number of CTAs in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_cta_shape_mnl(GemmCoord cluster_shape, uint32_t cta_m, uint32_t cta_n) { + // Round up to nearest multiple of cluster dim along each mode + auto problem_blocks_m = ((cta_m + cluster_shape.m() - 1) / cluster_shape.m()) * cluster_shape.m(); + auto problem_blocks_n = ((cta_n + cluster_shape.n() - 1) / cluster_shape.n()) * cluster_shape.n(); + + return { + static_cast(problem_blocks_m), + static_cast(problem_blocks_n), + static_cast(1) // Only a single batch per group is currently supported + }; + } + + // Version of get_grid_shape that takes in as input the number of CTAs in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or CTA shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + dim3 problem_blocks, + GemmCoord cluster_shape, + KernelHardwareInfo hw_info, + int max_swizzle_size, + RasterOrderOptions raster_order_option, + bool truncate_by_problem_size=true) { + + int const sm_count = hw_info.sm_count; + + // Round up to nearest multiple of swizzle_size along each mode + auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); + auto problem_blocks_m = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + + int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks.z; + + RasterOrder raster_order = get_rasterization_order( + problem_blocks_m, + problem_blocks_n, + raster_order_option + ); + + dim3 launch_grid; + + if (raster_order == RasterOrder::AlongN) { + launch_grid = dim3(cluster_shape.m(), 1, 1); + } + else { + launch_grid = dim3(1, cluster_shape.n(), 1); + } + + auto possibly_truncate = [&](int x, int y) { + if (truncate_by_problem_size) { + return platform::min(x, y); + } + else { + return x; + } + }; + + // The else path is generic, however, we can avoid some divs if we know cluster size is 1 + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + if (cluster_size == 1) { + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); + } + else { + launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); + } + } + else { + // Optimal grid size calculation is based on + // GH100: 8 GPCs, 72 TPCs (9 TPCs/GPC), 2 SMs/TPC, 144 SMs per full GPU + // Hence, maximum SMs per GPC = 18 + constexpr int max_sm_per_gpc = 18; + // Provided SM count could possibly be less than the assumed maximum SMs per GPC + auto cluster_size = cluster_shape.m() * cluster_shape.n(); + int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; + int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % cluster_size); + int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; + + // The calculation below allows for larger grid size launch for different GPUs. + int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; + int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % cluster_size); + cta_per_device += max_cta_occupancy_per_residual_gpc; + + cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; + + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate( + cta_per_device / cluster_shape.m(), + problem_blocks_total / cluster_shape.m()); + } + else { + launch_grid.x = possibly_truncate( + cta_per_device / cluster_shape.n(), + problem_blocks_total / cluster_shape.n()); + } + } + return launch_grid; + } + + CUTLASS_HOST_DEVICE + static int32_t + get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { + int min_cta_dim = platform::min(problem_ctas_m, problem_ctas_n); + if (max_swizzle_size >= 8 && min_cta_dim >= 6) { + return 3; + } + else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { + return 2; + } + else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { + return 1; + } + else { + return 0; + } + } + + CUTLASS_HOST_DEVICE + static RasterOrder + get_rasterization_order( + uint32_t tiles_m, + uint32_t tiles_n, + RasterOrderOptions raster_order_option + ) { + + if (raster_order_option == RasterOrderOptions::Heuristic) { + if (tiles_n > tiles_m) { + return RasterOrder::AlongM; + } + else { + return RasterOrder::AlongN; + } + } + else { + switch (raster_order_option) { + case RasterOrderOptions::AlongN: + return RasterOrder::AlongN; + break; + default: + return RasterOrder::AlongM; + } + } + } +}; + //////////////////////////////////////////////////////////////////////////////// } // namespace detail } // namespace kernel diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index 1573e642..55692bdc 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -114,6 +114,9 @@ struct MmaGeneric { static bool const kMultipleOf2 = ((Shape::kM % 2 == 0) && (Shape::kN % 2 == 0)); + static bool const kAllFp32 = platform::is_same::value && + platform::is_same::value && + platform::is_same::value; // // Methods // @@ -144,11 +147,7 @@ struct MmaGeneric { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < Shape::kK; ++k) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 860) - if (kMultipleOf2 && - platform::is_same::value && - platform::is_same::value && - platform::is_same::value) { - + if (kMultipleOf2 && kAllFp32) { //2x2 zigzag - m and n loops to increment by 2. Inner loop to process 4 multiply-adds in a 2x2 tile. CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN; n+=2) { @@ -157,7 +156,7 @@ struct MmaGeneric { for (int m = 0; m < Shape::kM; m+=2) { int m_serpentine = (n % 4) ? (Shape::kM - 2 - m) : m; - + //top-left element in 2x2 tile { MatrixCoord mn(m_serpentine, n); diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 0f10e865..4ec0869b 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -60,9 +60,6 @@ namespace layout { // ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag used for 3-D NWC tensors for 1D conv, only used in 3.x API -class TensorNWC {}; - /// Mapping function for 4-D NHWC tensors. class TensorNHWC { public: @@ -632,6 +629,14 @@ public: } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag used for linearized tensors with shape (NW, C) for 1D conv, only used in 3.x API +class TensorLinearizedNWC {}; +/// Tag used for linearized tensors with shape (NHW, C) for 2D conv, only used in 3.x API +class TensorLinearizedNHWC : public TensorNHWC {}; +/// Tag used for linearized tensors with shape (NDHW, C) for 3D conv, only used in 3.x API +class TensorLinearizedNDHWC : public TensorNDHWC {}; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 4d2faab0..7a8c7a9f 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -32,6 +32,17 @@ \file \brief Boost-like numeric conversion operator for CUTLASS numeric types */ + +/* + Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain + existing integrations of CUTLASS require C++11 host compilers. + + Until this requirement can be lifted, certain headers with this annotation are required + to be remain consistent with C++11 syntax. + + C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. +*/ + #pragma once #if !defined(__CUDACC_RTC__) @@ -2482,6 +2493,1071 @@ struct NumericArrayConverter { #endif // Conditional guards to enable partial specialization for packed integers +namespace detail { + + /* + A helper class that can vectorize a numeric converter with implementation for several vector widths. + + The vector widths must be giving in decreasing order or width, and must be a power of 2. + + The vector converters must produce identical results to the scalar converters for consistency. + */ + class VectorizedConverter { + private: + // Base case to handle remainder elements as scalars. + template + CUTLASS_DEVICE + static void convert_helper( + typename ArrayConverter::result_type& result, + typename ArrayConverter::source_type const& source) { + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + // If no more converters, handle the remaining elements as scalars. + constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int remainder = total_elements - Offset; + static_assert(remainder == (total_elements % ParentWidth), "Unexpected remainder."); + + typename ArrayConverter::ScalarConverter scalar_converter; + CUTLASS_PRAGMA_UNROLL + for (int i = Offset; i < ArrayConverter::result_type::kElements; ++i) { + result[i] = scalar_converter(ElementSrc(source[i])); + } + } + + template + CUTLASS_DEVICE + static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); + static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); + static_assert(cutlass::platform::is_same::value, + "ResultVectorArray must have the same type ArrayConverter::result_type"); + static_assert(cutlass::platform::is_same::value, + "SourceVectorArray must have the same type ArrayConverter::result_type"); + static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); + + static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); + + constexpr int vector_width = ResultVectorArray::kElements; + static_assert(ispow2(vector_width), "Vector width must be a power of 2"); + + using ElementRes = typename ArrayConverter::result_type::Element; + using ElementSrc = typename ArrayConverter::source_type::Element; + + constexpr int vector_bits_res = vector_width * cutlass::sizeof_bits::value; + constexpr int vector_bits_src = vector_width * cutlass::sizeof_bits::value; + + static_assert(vector_bits_res % 8 == 0, "Result vector type must be byte addressed."); + static_assert(vector_bits_src % 8 == 0, "Source vector type must be byte addressed."); + + constexpr int vector_offset = Offset / vector_width; + ResultVectorArray* packed_result_vec = reinterpret_cast(&result) + vector_offset; + SourceVectorArray const* packed_source_vec = reinterpret_cast(&source) + vector_offset; + + // Convert the remaining elements as vectors. + constexpr int total_elements = ArrayConverter::result_type::kElements; + constexpr int groups_of_vec = (total_elements - Offset) / vector_width; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < groups_of_vec; ++i) { + packed_result_vec[i] = ArrayConverter::template packed_convert(packed_source_vec[i]); + } + + constexpr int new_offset = Offset + vector_width * groups_of_vec; + // Recurse to handle other vector converters, or the scalar base case. + convert_helper(result, source); + } + + public: + /* + A method to convert vectors of elements using the packed_convert method of the converter. + + Converters using this class must implement packed convert and support 1 or more vector conversions. + */ + template + CUTLASS_DEVICE + static void convert(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { + convert_helper<0, 0, ArrayConverter, ResultVectorArray, SourceVectorArray, OtherVectorArrays...>(result, source); + } + }; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses a lookup table to converts i4 -> e4m3. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 4 or 8 to use private convert dispatch."); + + // Hold FP8 outputs in reg. We need 1 reg for every 4 outputs. + cutlass::AlignedArray r; + + // View the input as reg + uint32_t reg = to_reg(source); + + // Determines if to get from the signed or unsigned candidates + uint32_t sign = (reg & 0x88888888) >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = (reg & 0x77777777); + + // Signed is OR'd with 0x32103210 to find the correct value in the LUT + const uint32_t final_prmt_base = 0x32103210; + + // [0, 1, 2, 3] encoded as FP8 + static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800; + // [4, 5, 6, 7] encoded as FP8 + static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48; + // [-1, -2, -3, -4] encoded as FP8 + static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0; + // [-5, -6, -7, -7] encoded as FP8 + static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; + + + const int iters = PackedSrcType::kElements / 4; + #pragma unroll + for (int ii = 0; ii < iters; ++ii, lut_idx >>=16, sign >>=16) { + uint32_t final_prmt_idx = final_prmt_base | sign; + + // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value + // as the index to prmt. + // It first select both the positive and negative candidates, then uses the sign bit to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 pos_f8s, neg_f8s;\n" + " prmt.b32 pos_f8s, %1, %2, %5;\n" + " prmt.b32 neg_f8s, %3, %4, %5;\n" + " prmt.b32 %0, pos_f8s, neg_f8s, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(POS_E4M3s_REG1), "n"(POS_E4M3s_REG2), "n"(NEG_E4M3s_REG1), "n"(NEG_E4M3s_REG2), + "r"(lut_idx), "r"(final_prmt_idx)); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static void packed_convert_vec(PackedResultType& result, uint32_t src_reg) { + static_assert(offset == 0 || offset == 4, "Invalid offset"); + // Selects one of the bottom int4s and constructs: + // 8388608 + (x + 8) + // 8388608 + 16 * (x + 8) + // 8388608 + 256 * (x + 8) + // 8388608 + 4096 * (x + 8) + uint32_t const and_masks[4] = {0x0000000F, 0x000000F0, 0x00000F00, 0x0000F000}; + uint32_t const xor_masks[4] = {0x4B000008, 0x4B000080, 0x4B000800, 0x4B008000}; + + float const scales[4] = {1.f, 1.f / 16.f, 1.f / 256.f, 1.f / 4096.f}; + float const offsets[4] = {-8388616.f, -524296.f, -32776.f, -2056.f}; + + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + uint32_t* result_as_int = reinterpret_cast(&result); + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < elements_to_convert; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(result_as_int[offset + ii]) + : "r"(src_reg), "r"(and_masks[ii]), "r"(xor_masks[ii]), "n"(immLut)); + + result[offset + ii] = __fmaf_rn(result[offset + ii], scales[ii], offsets[ii]); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then does a + // subtraction in FP16 for the final result. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 1, 2, 4 or 8 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + PackedResultType r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + constexpr int total_elements = PackedResultType::kElements == 8 ? 4 : PackedResultType::kElements; + packed_convert_vec<0, total_elements>(r, src_reg); + + + if (PackedResultType::kElements == 8) { + uint32_t src_reg_shifted = src_reg >> 16; + packed_convert_vec<4, 4>(r, src_reg_shifted); + } + return r; + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + PackedResultType r; + // View the input as reg + uint32_t src_reg = to_reg(source); + static constexpr int fp32_base = 0x4B400000; + uint32_t const prmt_indices[4] = {0x8880, 0x9991, 0xAAA2, 0xBBB3}; + + int* result_as_int = reinterpret_cast(&r); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_as_int[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); + } + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < PackedResultType::kElements; ++ii) + { + result_as_int[ii] += fp32_base; + r[ii] -= reinterpret_cast(fp32_base); + } + + return r; + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + PackedResultType r; + // View the input as reg + uint32_t src_reg = to_reg(source); + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores + // the result in r (without introducing extra cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src_reg, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result + r[ii] -= 8388608.f; + } + + return r; + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses bit tricks to construct a known FP16 number, then does a + // subtraction in FP16 for the final result. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want the documented (& 0x7) on the index. NVCC + // might be able to optimize it out since the index is a constexpr, but we choose to be safe about it here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); + } + + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing + // 1024 + x + 8 OR 1024 + 16 * (x + 8), then using hfma to subtract 1032 from that + // 2) Adds 8 to the int4 value that we will process in the FP16 (for uint4, we can simply avoid this step) + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64806408; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // For the high FP16: + // Divide by 16 {packed as a operand} to get: + // 64 + (x + 8) + // x + 72 + // Subtract 72 {packed as c operand} to get x + // For the low FP16: + // 1024 + (x + 8) + // x + 1032 + // So, we subtract 1032 {packed as c operand} to get x + + // {-72, -1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; + // {1 / 16, 1} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + // Scale and subtract the FP16s to get the original int4 number as FP16. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses bit tricks to construct a known FP16 number, then does a + // subtraction in FP16 for the final result. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) + auto result = reinterpret_cast(r); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < PackedResultType::kElements; ++i) { + int16_t tmp = source[i] + 26112 /* 0x6600 */; + result[i] = reinterpret_cast(tmp) - 1536.0_hf; + } + #endif + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t const prmt_indices[2] = {0x9180, 0xB3A2}; + + // Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0]) + // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) + // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign-extend the sign bit in 0, 1, 2, 3 bytes of s8x4 + // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively. + // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same result and doesn't sign-extend the sign bit. + // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from s8x2 to s16x2. + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(r[ii]) : "r"(src_reg), "r"(prmt_indices[ii])); + } + + // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve + // the same result as add.s16x2 instruction. + // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) + // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to + // three predefined constant values as follows: + // ta = 0xF0; + // tb = 0xCC; + // tc = 0xAA; + // kImmLut = F(ta, tb, tc); + // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA + static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + // The bit-wise operation executed below is `r[ii] = (r[ii] & 0x03FF03FF) ^ 0x66006600;` + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : + "=r"(r[ii]) : "r"(r[ii]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); + } + + static constexpr uint32_t bias_rep = 0x66006600; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(r[ii]) : "r"(src_reg), "n"(start_byte_for_fp16), "r"(prmt_indices[ii])); + } + + static constexpr uint32_t bias_rep = 0x64006400; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_8 = Array; + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_8 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_8 const& source) { + return reinterpret_cast(source); + } + + // The core converter uses bit tricks to construct a known FP16 number, then does a + // subtraction in FP16 for the final result. + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); + + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = cutlass::AlignedArray; + RegArray r; + + // View the input as reg + uint32_t src_reg = to_reg(source); + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, "Too many inputs for BF16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the FP16 magic_num. We will be constructing + // 128 + (x + 8) and subtracting 136 to get x + static constexpr uint32_t xor_mask = 0x43084308; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + NumericArrayConverter convert_int8_to_f32; + Array tmp = convert_int8_to_f32(source); + NumericArrayConverter convert_f32_to_bf16; + return convert_f32_to_bf16(tmp); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + +private: + using result_type_packed_4 = Array; + using result_type_packed_2 = Array; + using source_type_packed_4 = Array; + using source_type_packed_2 = Array; + + using ScalarConverter = NumericConverter; + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_2 const& source) { + return static_cast( + reinterpret_cast(source)); + } + + CUTLASS_DEVICE + static uint32_t to_reg(source_type_packed_4 const& source) { + return reinterpret_cast(source); + } + + template + CUTLASS_DEVICE + static PackedResultType packed_convert(PackedSrcType const &source) { + + static_assert((platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); + + NumericArrayConverter convert_uint8_to_f32; + Array tmp = convert_uint8_to_f32(source); + NumericArrayConverter convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + +public: + CUTLASS_DEVICE + static result_type convert(source_type const &source) { + result_type result; + using ConverterType = NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +#endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + ///////////////////////////////////////////////////////////////////////////////////////////////// /// FastNumericArrayConverter only works when the source is within center range. @@ -2507,12 +3583,10 @@ struct FastNumericArrayConverter { }; /// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter::is_integer> -> { +template +struct FastNumericArrayConverter { using result_type = Array; - using source_type = Array; + using source_type = Array; static FloatRoundStyle const round_style = Round; CUTLASS_DEVICE @@ -2592,223 +3666,6 @@ struct FastNumericArrayConverter { result_type operator()(source_type const &s) const { return convert(s); } }; -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - int16_t tmp = source[i] + 26112 /* 0x6600 */; - result[i] = reinterpret_cast(tmp) - 1536.0_hf; - } - #endif - - // Vectorized s8->f16 conversion using packed instructions - uint32_t const* source_ptr = reinterpret_cast(&source); - uint32_t* result_ptr = reinterpret_cast(&result); - - // Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0]) - // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) - // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in 0, 1, 2, 3 bytes of s8x4 - // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively. - // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't acheive the same and doesn't sign extend the sign-bit. - // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from s8x2 to s16x2. - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180)); - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2)); - - // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve - // the same result as add.s16x2 instruction. - // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) - // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to - // three predefined constant values as follows: - // ta = 0xF0; - // tb = 0xCC; - // tc = 0xAA; - // kImmLut = F(ta, tb, tc); - // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA - static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; - - // The bit-wise operation executed below is `result_ptr[0] = (result_ptr[0] & 0x03FF03FF) ^ 0x66006600;` - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : - "=r"(result_ptr[0]) : "r"(result_ptr[0]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); - // The bit-wise operation executed below is `result_ptr[1] = (result_ptr[1] & 0x03FF03FF) ^ 0x66006600;` - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : - "=r"(result_ptr[1]) : "r"(result_ptr[1]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); - - // Packed sub.f16x2 with magic number to obtain final converted result - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - - uint32_t const* source_ptr = reinterpret_cast(&source); - uint32_t* result_ptr = reinterpret_cast(&result); - - result_ptr[0] = __byte_perm(source_ptr[0], 0x0, 0x4140); - result_ptr[1] = __byte_perm(source_ptr[0], 0x0, 0x4342); - - asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); - asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); - - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - Array tmp; - - uint32_t const* source_ptr = reinterpret_cast(&source); - uint32_t* tmp_ptr = reinterpret_cast(&tmp); - - // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores - // the result in tmp (without introducing extra cvt.u32.u8 instruction) - tmp_ptr[0] = __byte_perm(source_ptr[0], 0x4B000000, 0x7650); - tmp_ptr[1] = __byte_perm(source_ptr[0], 0x4B000000, 0x7651); - tmp_ptr[2] = __byte_perm(source_ptr[0], 0x4B000000, 0x7652); - tmp_ptr[3] = __byte_perm(source_ptr[0], 0x4B000000, 0x7653); - - // Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < 4; ++i) { - tmp[i] = reinterpret_cast(tmp_ptr[i]) - 8388608.f; - } - - // on 3456x4096x8192 runs at 158 TFLOP/s - // Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction - NumericArrayConverter convert_f32_to_bf16; - result = convert_f32_to_bf16(tmp); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template -struct FastNumericArrayConverter { - using result_type = Array; - using source_type = Array; - - using intermediate_float_type = Array; - using intermediate_int32_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - result_type result; - intermediate_float_type tmp; - - uint32_t const* source_ptr = reinterpret_cast(&source); - uint32_t* tmp_ptr = reinterpret_cast(&tmp); - - // s8x4 (s[3], s[2], s8[1], s8[0]) -> s16x4 (sext.s8[3], sext.s8[2], sext.s8[1], sext.s8[0]) - // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) - // The inline ptx below uses `msb=0` and `msb=1` from the above link to sext the sign-bit in 0, 1, 2, 3 bytes of s8x4 - // sext without unpacking each s8 out of s8x4 into a separate register a.ka. without using shifts (SHFL). - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[0]) : "r"(source_ptr[0]), "n"(0x8880)); - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[1]) : "r"(source_ptr[0]), "n"(0x9991)); - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[2]) : "r"(source_ptr[0]), "n"(0xAAA2)); - asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[3]) : "r"(source_ptr[0]), "n"(0xBBB3)); - - // Convert s32x4 to f32x4 using fast numeric array converter - FastNumericArrayConverter convert_s32_to_f32_; - tmp = convert_s32_to_f32_(reinterpret_cast(tmp[0])); - - // Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction - NumericArrayConverter convert_f32_to_bf16_; - result = convert_f32_to_bf16_(tmp); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for FastNumericArrayConverter to vectorize over 4 elements. -/// source `S` as 8b integers (S8 or U8) -> destination `T` as 16b floating-point (F16 or BF16) -template -struct FastNumericArrayConverter::value || platform::is_same::value) && - (platform::is_same::value || platform::is_same::value)>::type> { - static_assert(!(N % 4), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_DEVICE - static result_type convert(source_type const &source) { - FastNumericArrayConverter convert_vector_; - result_type result; - - Array *result_ptr = - reinterpret_cast *>(&result); - Array const *source_ptr = - reinterpret_cast const *>(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / 4; ++i) { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const &s) const { return convert(s); } - -}; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines preferred rounding mode for a pair of types diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 4a50328d..783fc919 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -48,7 +48,8 @@ using namespace cute; enum class BarrierStatus : uint32_t { WaitAgain = 0u, - WaitDone = 1u + WaitDone = 1u, + WaitOnly = 2u }; class ArrivalToken { @@ -81,6 +82,16 @@ private: friend bool operator==(const BarrierStatus& left, const ArrivalToken& right) { return left == right.get(); } + + CUTLASS_HOST_DEVICE + friend bool operator!=(const ArrivalToken& left, const BarrierStatus& right) { + return left.get() != right; + } + + CUTLASS_HOST_DEVICE + friend bool operator!=(const BarrierStatus& left, const ArrivalToken& right) { + return left != right.get(); + } }; class ProducerToken : public ArrivalToken { @@ -188,13 +199,9 @@ PipelineState make_producer_start_state() { // Assumptions : Constructor is visible Cluster-wide (as it needs a Cluster-Sync) // We have exactly one thread elected in the Producer as the "leader" // Currently, it is optional to elect a leader for the Consumers -template < - int Stages_, - class ClusterShape_ -> +template class PipelineTmaAsync { public : - using ClusterShape = ClusterShape_; using FullBarrier = cutlass::arch::ClusterTransactionBarrier; using EmptyBarrier = cutlass::arch::ClusterBarrier; using ProducerBarrierType = FullBarrier::ValueType; @@ -222,15 +229,15 @@ public : }; // Constructor + template CUTLASS_DEVICE - PipelineTmaAsync(SharedStorage& storage, Params params) + PipelineTmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape) : params_(params) , full_barrier_ptr_(&storage.full_barrier_[0]) , empty_barrier_ptr_(&storage.empty_barrier_[0]) { int warp_idx = canonical_warp_idx(); int lane_predicate = cute::elect_one_sync(); - auto cluster_shape = ClusterShape{}; if (warp_idx == 0 && lane_predicate == 1) { // Barrier FULL init @@ -283,7 +290,8 @@ public : is_signalling_thread_ &= dst_blockid_ < cluster_size; is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); } - + + template CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) || @@ -332,7 +340,7 @@ public : CUTLASS_DEVICE void producer_tail(PipelineState state) { for (int count = 0; count < Stages; ++count) { - producer_acquire(state); + producer_acquire(state, {BarrierStatus::WaitOnly}); ++state; } } @@ -388,9 +396,12 @@ private : CUTLASS_DEVICE void producer_acquire(uint32_t stage, uint32_t phase, ProducerToken barrier_token) { - if (barrier_token == BarrierStatus::WaitAgain) { + if (barrier_token != BarrierStatus::WaitDone) { empty_barrier_ptr_[stage].wait(phase); } + if (barrier_token == BarrierStatus::WaitOnly) { + return; + } if (params_.is_leader) { full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); @@ -417,7 +428,7 @@ private : full_barrier_ptr_[stage].complete_transaction(bytes); // STEP 2 : Commit to other blocks in our cluster - auto cluster_shape = ClusterShape{}; + auto cluster_shape = cute::cluster_shape(); Layout block_layout_in_cluster = make_layout(cluster_shape); dim3 local_block_id = cute::block_id_in_cluster(); diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index 8ef6d187..0b7191ae 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -394,7 +394,13 @@ public: /// Unpacks an element from memory CUTLASS_HOST_DEVICE Element get() const { - Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); + uint8_t const* byte_ptr = reinterpret_cast(ptr_); + // Convert offset in elements to offset in bytes + constexpr int elements_per_byte = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + byte_ptr += offset_ / elements_per_byte; + // Offset of element within a byte + int byte_offset = offset_ % elements_per_byte; + uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits::value)) & kMask); return reinterpret_cast(item); } @@ -607,6 +613,7 @@ public: ///////////////////////////////////////////////////////////////////////////////////////////////// +template using _war = T; template < typename Element_, /// CUTLASS numeric element type. typename Storage_ /// Underlying basic storage type. @@ -647,7 +654,7 @@ private: StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); /// Pointer to array containing element - StorageVecPointer ptr_; + _war ptr_; /// Offset (in units of elements) from pointer. /// @@ -979,6 +986,7 @@ public: } }; +template using _war = T; template < typename Element_, /// CUTLASS numeric element type. typename Storage_ /// Underlying storage type. Must be able to hold an integer @@ -1019,7 +1027,7 @@ private: StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); /// Pointer to array containing element - StorageVecPointer ptr_; + _war ptr_; /// Offset (in units of elements) from pointer. /// @@ -1276,6 +1284,10 @@ struct ReferenceFactory; template struct ReferenceFactory { + + ///! Number of elements per storage vector + static int const kElementsPerVector = 1; + CUTLASS_HOST_DEVICE static Element &get(Element *ptr, int64_t offset) { return ptr[offset]; @@ -1285,10 +1297,25 @@ struct ReferenceFactory { static Element const &get(Element const *ptr, int64_t offset) { return ptr[offset]; } + + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset) { + return ptr + offset; + } + + CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset) { + return ptr + offset; + } }; template struct ReferenceFactory { + + // + // Static methods + // + CUTLASS_HOST_DEVICE static SubbyteReference get(Element *ptr, int64_t offset) { return SubbyteReference(ptr, offset); @@ -1299,6 +1326,22 @@ struct ReferenceFactory { int64_t offset) { return ConstSubbyteReference(ptr, offset); } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) { + + return ptr + offset_in_elements * sizeof_bits::value / sizeof(Element) / 8; + } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. + CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) { + + return ptr + offset_in_elements * sizeof_bits::value / sizeof(Element) / 8; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp index ac67da08..8b8058b8 100644 --- a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +++ b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -103,14 +103,14 @@ public: using SmemLayoutB = SmemLayoutB_; using SmemLayoutAtomB = SmemLayoutAtomB_; using ElementB = ElementB_; - - constexpr CUTLASS_HOST_DEVICE + + constexpr CUTLASS_HOST_DEVICE NoTranspositionOperandB( - int, - int, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, + int, + int, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, ElementB) { } template < @@ -148,12 +148,12 @@ public: constexpr CUTLASS_HOST_DEVICE UniversalTranspositionOperandB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) : warp_idx(warp_idx_) , warp_group_thread_idx(warp_group_thread_idx_) { } @@ -168,9 +168,9 @@ public: return; } - constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; - static_assert(NumMathWarpGroup == 1 || - (!detail::use_universal_transposition() && NumMathWarpGroup == 2), + constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static_assert(NumMathWarpGroup == 1 || + (!detail::use_universal_transposition() && NumMathWarpGroup == 2), "Wrong math warp group number for TransposeB"); constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. @@ -234,14 +234,14 @@ public: if (step == 0) { // SMEM fence to make sure B is transposed before math cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } } CUTLASS_DEVICE void synchronize() { // SMEM fence to make sure B is transposed before math cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } template < @@ -251,7 +251,7 @@ public: TensorSmemB const& sB, TensorTransposedSmemB const& gmma_sB, int read_stage) { - + this->operator()(sB, gmma_sB, read_stage, 0); synchronize(); @@ -267,7 +267,7 @@ template< class SmemLayoutB_, class SmemLayoutAtomB_, class ElementB_> -class AsyncTranspositionOperandB { +class AsyncTranspositionOperandB { public: using TiledMma = TiledMma_; @@ -276,9 +276,9 @@ public: using ElementB = ElementB_; static constexpr int Steps = 2; - static constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; - static_assert(NumMathWarpGroup <= 2, + static_assert(NumMathWarpGroup <= 2, "Wrong math warp group number for TransposeB"); static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; @@ -303,23 +303,23 @@ public: "Copy size must evenly divide SMEM tile."); static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, + static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, "Need to be able to transpose first k-block in the first step"); - constexpr CUTLASS_HOST_DEVICE + constexpr CUTLASS_HOST_DEVICE AsyncTranspositionOperandB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) : warp_idx(warp_idx_) , warp_group_thread_idx(warp_group_thread_idx_) , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) - , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) - , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } template < @@ -328,12 +328,12 @@ public: CUTLASS_DEVICE void operator()( TensorSmemB const& sB, TensorTransposedSmemB const& gmma_sB, - int read_stage, int current_step) + int read_stage, int current_step) { if (current_step >= StepsPerWarpGroup) { return; } - + static constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); ////////////////////////////////////////////////////////////////////////////////////////////////////////////// /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. @@ -384,11 +384,11 @@ public: }; [[maybe_unused]] int step = current_step * NumMathWarpGroup; - if constexpr (NumMathWarpGroup == 2) { - // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + if constexpr (NumMathWarpGroup == 2) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. step += warp_idx / (NumWarpsPerWarpGroup * 2); } - + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step); int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step); @@ -396,7 +396,7 @@ public: tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); } - + // decoding the warp tile coord. int warp_tile0_n, warp_tile0_k; if constexpr (StepsPerWarpGroup <= NumStepsEncoded) { @@ -412,7 +412,7 @@ public: CUTLASS_PRAGMA_UNROLL for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { - + static_assert(TilesPerWarp == 2); // [warp_tile][n/k] @@ -427,7 +427,7 @@ public: Tensor tCsB = sB_thr_copy.partition_S( flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) ); // (CPY, CPY_N, CPY_K) - + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); } @@ -442,20 +442,20 @@ public: copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); } - } // loop warp_group_tile + } // loop warp_group_tile } CUTLASS_DEVICE void synchronize(int step) { if (step < StepsPerWarpGroup) { // SMEM fence to make sure B is transposed before math cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } } CUTLASS_DEVICE void synchronize() { cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } template < @@ -465,7 +465,7 @@ public: TensorSmemB const& sB, TensorTransposedSmemB const& gmma_sB, int read_stage) { - + CUTLASS_PRAGMA_UNROLL for(int i = 0; i < StepsPerWarpGroup; ++i) { this->operator()(sB, gmma_sB, read_stage, i); @@ -486,7 +486,7 @@ template< class SmemLayoutB_, class SmemLayoutAtomB_, class ElementB_> -class AsyncTranspositionOperandB_1BElementB { +class AsyncTranspositionOperandB_1BElementB { public: static_assert(sizeof(ElementB_) == 1); @@ -495,11 +495,11 @@ public: using SmemLayoutB = SmemLayoutB_; using SmemLayoutAtomB = SmemLayoutAtomB_; using ElementB = ElementB_; - + static constexpr int Steps = 8; - static constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; - static_assert(NumMathWarpGroup <= 2, + static_assert(NumMathWarpGroup <= 2, "Wrong math warp group number for TransposeB"); static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; @@ -524,21 +524,20 @@ public: "Copy size must evenly divide SMEM tile."); static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - - constexpr CUTLASS_HOST_DEVICE + constexpr CUTLASS_HOST_DEVICE AsyncTranspositionOperandB_1BElementB( - int warp_idx_, - int warp_group_thread_idx_, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, - ElementB) + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) : warp_idx(warp_idx_) , warp_group_thread_idx(warp_group_thread_idx_) , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) - , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) - , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } template < @@ -547,7 +546,7 @@ public: CUTLASS_DEVICE void operator()( TensorSmemB const& sB, TensorTransposedSmemB const& gmma_sB, - int read_stage, int current_step) + int read_stage, int current_step) { if (current_step > 0) { return; @@ -628,7 +627,7 @@ public: CUTLASS_PRAGMA_NO_UNROLL for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) { - // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2); // decoding the warp tile coord. int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; @@ -653,7 +652,7 @@ public: Tensor tCsB = sB_thr_copy.partition_S( flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) ); // (CPY, CPY_N, CPY_K) - + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); } @@ -675,13 +674,13 @@ public: if (step == 0) { // SMEM fence to make sure B is transposed before math cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } } CUTLASS_DEVICE void synchronize() { cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); } template < @@ -711,35 +710,35 @@ template< class ElementB, bool TransposeB > -constexpr CUTLASS_HOST_DEVICE -auto +constexpr CUTLASS_HOST_DEVICE +auto make_transpose_operand_b( - int warp_idx, - int warp_group_thread_idx, - TiledMma, - SmemLayoutB, - SmemLayoutAtomB, + int warp_idx, + int warp_group_thread_idx, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, ElementB, cute::bool_constant) { if constexpr (!TransposeB) { return NoTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, + warp_idx, warp_group_thread_idx, TiledMma{}, SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); } else if constexpr (use_universal_transposition()) { return UniversalTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, + warp_idx, warp_group_thread_idx, TiledMma{}, SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); } else if constexpr (sizeof(ElementB) == 1) { return AsyncTranspositionOperandB_1BElementB( - warp_idx, warp_group_thread_idx, TiledMma{}, + warp_idx, warp_group_thread_idx, TiledMma{}, SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); } else { return AsyncTranspositionOperandB( - warp_idx, warp_group_thread_idx, TiledMma{}, + warp_idx, warp_group_thread_idx, TiledMma{}, SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); } } diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h index e509ddc6..616d45d9 100755 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -71,7 +71,6 @@ struct PredicatedTileAccessIteratorDesc { // Methods // - CUTLASS_HOST_DEVICE PredicatedTileAccessIteratorDesc() = default; CUTLASS_HOST_DEVICE @@ -279,7 +278,6 @@ struct PredicatedTileAccessIteratorParams { return initialize(LongIndex(stride), desc); } - CUTLASS_HOST_DEVICE PredicatedTileAccessIteratorParams() = default; CUTLASS_HOST_DEVICE diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index 155e0425..9ca7dcd9 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -56,9 +56,9 @@ #include "cutlass/cutlass.h" /// Optionally enable GCC's built-in type -#if (defined(__x86_64) || defined (__aarch64__)) && !defined(__CUDA_ARCH__) && defined(__GNUC__) +#if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) #define CUTLASS_UINT128_NATIVE -#elif defined(_MSC_VER) && defined(_M_AMD64) && !defined(__CUDA_ARCH__) +#elif defined(_MSC_VER) && defined(_M_AMD64) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) #define CUTLASS_INT128_ARITHMETIC #include #if _MSC_VER >= 1920 diff --git a/media/docs/build/building_with_clang_as_host_compiler.md b/media/docs/build/building_with_clang_as_host_compiler.md index 6a46e82f..9abf6d52 100644 --- a/media/docs/build/building_with_clang_as_host_compiler.md +++ b/media/docs/build/building_with_clang_as_host_compiler.md @@ -9,7 +9,8 @@ Clang as both host and device compiler ("CUDA Clang"). # Software prerequisites -1. Clang (tested with Clang 14) +1. Clang (regularly tested with Clang 14; + occasionally tested with Clang 10 and greater) 2. CUDA Toolkit (tested with 12.2; other versions likely work) @@ -32,14 +33,18 @@ is the following error when attempting to use clang: # Running CMake -The Clang build requires specifying the following three CMake options. +## Required CMake options -* `CMAKE_CXX_COMPILER=clang++` -* `CMAKE_CUDA_HOST_COMPILER=clang++` +The Clang build requires specifying the following CMake options. +Replace `` with the path to your `clang++` executable, +and replace `` with the path to your `clang` executable +(which must have the same version as your `clang++` executable). +You may use `clang++` resp. `clang` directly if they are in your `PATH`. -* `CMAKE_C_COMPILER=clang` +* `CMAKE_CXX_COMPILER=` +* `CMAKE_CUDA_HOST_COMPILER=` +* `CMAKE_C_COMPILER=` -This assumes that `clang++` and `clang` are in the user's `PATH`. Please note that both `CMAKE_CXX_COMPILER` and `CMAKE_C_COMPILER` must be set, even though CUTLASS is a C++ project, not a C project. @@ -51,3 +56,4 @@ if `${PATH_TO_CUDA_TOOLKIT}` is the CUDA Toolkit directory, then one can set `CMAKE_CUDA_COMPILER` as follows. * `CMAKE_CUDA_COMPILER=${PATH_TO_CUDA_TOOLKIT}/bin/nvcc` + diff --git a/media/docs/cute/0z_tma_tensors.md b/media/docs/cute/0z_tma_tensors.md index 3e0d0b1c..93f0db25 100644 --- a/media/docs/cute/0z_tma_tensors.md +++ b/media/docs/cute/0z_tma_tensors.md @@ -1,83 +1,233 @@ -# TMA tensors +# CuTe TMA Tensors -TMA tensors have three differences from -"ordinary" global memory tensors. +Along your travels, you may find strange looking CuTe Tensors that are printed as something like +``` +ArithTuple(0,_0,_0,_0) o ((_128,_64),2,3,1):((_1@0,_1@1),_64@1,_1@2,_1@3) +``` +What is an `ArithTuple`? Are those tensor strides? What do those mean? What is this for? -1. The tensor's iterator stores a base coordinate, - not a pointer. +This documentation intends to answer those questions and introduce some of the more advanced features of CuTe. -2. The tensor's actual global memory pointer - does not live in the tensor. - Instead, it lives in a TMA descriptor, - which is stored in the TMA `Copy_Traits` specialization. +# Introduction to TMA instructions -3. The tensor's strides aren't just integers. - Instead, they are linear combinations of "basis functions." +The Tensor Memory Accelerator (TMA) is a set of instructions for copying possibly multidimensional arrays between global and shared memory. TMA was introduced in the Hopper architecture. A single TMA instruction can copy an entire tile of data all at once. As a result, the hardware no longer needs to compute individual memory addresses and issue a separate copy instruction for each element of the tile. -The following sections will elaborate these differences. +To accomplish this, the TMA instruction is given a *TMA descriptor*, which is a packed representation of a multidimensional tensor in global memory with 1, 2, 3, 4, or 5 dimensions. The TMA descriptor holds -## Iterator stores a base coordinate, not a pointer +* the base pointer of the tensor; -"Ordinary" tensors of global memory have an iterator type -(the "Engine" template parameter) that wraps a pointer. -For example, `gmem_ptr` wraps a `T*`. -A TMA tensor's iterator type is `ArithmeticTupleIterator`. -`ArithmeticTupleIterator` stores a coordinate -(a tuple of integers) instead of a pointer. -The coordinate is represented as an `ArithmeticTuple`, -which is just a (public subclass of) `cute::tuple` -that has an overloaded `operator+`. -The sum of two tuples is the tuple of the sum of the elements. +* the data type of the tensor's elements (e.g., `int`, `float`, `double`, or `half`); -When we perform the TMA load or store, -the iterator's coordinate goes into the PTX instruction. -(For TMA specializations of `Copy_Traits`, -this happens in the `private` member function `copy_unpack_`.) -The coordinate represents the tensor's "base coordinate." -For tiled TMA, the base coordinate of the whole tensor -might start out as (0, 0, ..., 0). However, slicing the tensor -might result in a different base coordinate. -For im2col TMA load, the base coordinate is the lower corner. +* the size of each dimension; -## Pointer lives in TMA descriptor, not tensor +* the stride within each dimension; and -The TMA descriptor has the actual pointer to global memory in it. -Storing the TMA descriptor in the tensor would make tensors -expensive to copy and slice, as the TMA descriptor is 128 bytes. -Instead, we store the TMA descriptor -in the `Copy_Traits` specialization. +* other flags representing the smem box size, smem swizzling patterns, and out-of-bounds access behavior. -## Tensor's strides aren't just integers +This descriptor must be created on the host before kernel execution. +It is shared between all thread blocks that will be issuing TMA instructions. +Once inside the kernel, the TMA is executed with the following parameters: -For "ordinary" tensors, the layout takes a coordinate -`(i, j)` as input, and returns a single integer offset `k`. -The resulting pointer-to-element -is the base pointer, plus the offset k. -However, TMA loads and stores don't take a pointer. -They take a TMA descriptor, and a coordinate `(i, j)`. -Building the strides out of "basis functions" -is the trick to make the layout return a coordinate -- -a tuple of integers -- instead of just a single integer offset. -A "basis function" for strides -is a lot like a basis function for Euclidean space, -except that strides' basis functions can be hierarchical. +* pointer to the TMA descriptor; + +* pointer to the SMEM; and + +* coordinates into the GMEM tensor represented within the TMA descriptor. + +For example, the interface for TMA-store with 3-D coordinates looks like this. + +```cpp +struct SM90_TMA_STORE_3D { + CUTE_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) { + // ... invoke CUDA PTX instruction ... + } +}; +``` + +We observe that the TMA instruction does not directly consume pointers to global memory. Indeed, the global memory pointer is contained in the descriptor, is considered constant, and is NOT a separate parameter to the TMA instruction. Instead, the TMA consumes TMA coordinates into the TMA's view of global memory that is defined in the TMA descriptor. + +That means that an ordinary CuTe Tensor that stores a GMEM pointer and computes offsets and new GMEM pointers is useless to the TMA. + +What do we do? + +# Building a TMA Tensor + +## Implicit CuTe Tensors + +All CuTe Tensors are compositions of Layouts and Iterators. An ordinary global memory tensor's iterator is its global memory pointer. However, a CuTe Tensor's iterator doesn't have to be a pointer; it can be any random-access iterator. + +One example of such an iterator is a *counting iterator*. +This represents a possibly infinite sequence of integers that starts at some value. +We call the members of this sequence *implicit integers*, +because the sequence is not explicitly stored in memory. +The iterator just stores its current value. + +We can use a counting iterator to create a tensor of implicit integers, +```cpp +Tensor A = make_tensor(counting_iterator(42), make_shape(4,5)); +print_tensor(A); +``` +which outputs +``` +counting_iter(42) o (4,5):(_1,4): + 42 46 50 54 58 + 43 47 51 55 59 + 44 48 52 56 60 + 45 49 53 57 61 +``` +This tensor maps logical coordinates to on-the-fly computed integers. Because it's still a CuTe Tensor, it can still be tiled and partitioned and sliced just like a normal tensor by accumulating integer offsets into the iterator. + +But the TMA doesn't consume pointers or integers, it consumes coordinates. Can we make a tensor of implicit TMA +coordinates for the TMA instruction to consume? If so, then we could presumably also tile and partition and slice that tensor of coordinates so that we would always have the right TMA coordinate to give to the instruction. + +## ArithTupleIterators and ArithTuples + +First, we build a `counting_iterator` equivalent for TMA coordinates. It should support + +* dereference to a TMA coordinate, and + +* offset by another TMA coordinate. + +We'll call this an `ArithmeticTupleIterator`. It stores a coordinate (a tuple of integers) that is represented as an `ArithmeticTuple`. The `ArithmeticTuple` is simply a (public subclass of) `cute::tuple` that has an overloaded `operator+` so that it can be offset by another tuple. The sum of two tuples is the tuple of the sum of the elements. + +Now similar to `counting_iterator(42)` we can create an implicit "iterator" (but without increment or other common iterator operations) over tuples that can be dereferenced and offset by other tuples +```cpp +ArithmeticTupleIterator citer_1 = make_inttuple_iter(42, Int<2>{}, Int<7>{}); +ArithmeticTupleIterator citer_2 = citer_1 + make_tuple(Int<0>{}, 5, Int<2>{}); +print(*citer_2); +``` +which outputs +``` +(42,7,_9) +``` + +A TMA Tensor can use an iterator like this to store the current TMA coordinate "offset". The "offset" here is in quotes because it's clearly not a normal 1-D array offset or pointer. + +In summary, one creates a TMA descriptor for the *whole global memory tensor*. The TMA descriptor defines a view into that tensor and the instruction takes TMA coordinates into that view. In order to generate and track those TMA coordinates, we define an implicit CuTe Tensor of TMA coordinates that can be tiled, sliced, and partitioned the exact same way as an ordinary CuTe Tensor. + +We can now track and offset TMA coordinates with this iterator, but how do we get CuTe Layouts to generate non-integer offsets? + +## Strides aren't just integers + +Ordinary tensors have a layout that maps +a logical coordinate `(i,j)` into a 1-D linear index `k`. +This mapping is the inner-product of the coordinate with the strides. + +TMA Tensors hold iterators of TMA coordinates. +Thus, a TMA Tensor's Layout must map a logical coordinate +to a TMA coordinate, rather than to a 1-D linear index. + +To do this, we can abstract what a stride is. Strides need not be integers, but rather any algebraic object that supports inner-product with the integers (the logical coordinate). The obvious choice is the `ArithmeticTuple` we used earlier since they can be added to each other, but this time additionally equipped with an `operator*` so it can also be scaled by an integer. + +### Aside: Integer-module strides + +A group of objects that support addition between elements and product between elements and integers is called an integer-module. + +Formally, an integer-module is an abelian group `(M,+)` equipped with `Z*M -> M`, where `Z` are the integers. That is, an integer-module `M` is +a group that supports inner products with the integers. +The integers are an integer-module. +Rank-R tuples of integers are an integer-module. + +In principle, layout strides may be any integer-module. + +### Basis elements + +CuTe's basis elements live in the header file `cute/numeric/arithmetic_tuple.hpp`. +To make it easy to create `ArithmeticTuple`s that can be used as strides, CuTe defines normalized basis elements using the `E` type alias. "Normalized" means that the scaling factor of the basis element is the compile-time integer 1. + +| C++ object | Description | String representation | +| --- | --- | --- | +| `E<>{}` | `1` | `1` | +| `E<0>{}` | `(1,0,...)` | `1@0` | +| `E<1>{}` | `(0,1,0,...)` | `1@1` | +| `E<0,1>{}` | `((0,1,0,...),0,...)` | `1@1@0` | +| `E<1,0>{}` | `(0,(1,0,...),0,...)` | `1@0@1` | + +The "description" column in the above table +interprets each basis element as an infinite tuple of integers, +where all the tuple's entries not specified by the element's type are zero. +We count tuple entries from left to right, starting with zero. +For example, `E<1>{}` has a 1 in position 1: `(0,1,0,...)`. +`E<3>{}` has a 1 in position 3: `(0,0,0,1,0,...)`. + +Basis elements can be *nested*. +For instance, in the above table, `E<0,1>{}` means that +in position 0 there is a `E<1>{}`: `((0,1,0,...),0,...)`. + +Basis elements can be *scaled*. +That is, they can be multiplied by an integer *scaling factor*. +For example, in `5*E<1>{}`, the scaling factor is `5`. +`5*E<1>{}` prints as `5@1` and means `(0,5,0,...)`. +The scaling factor commutes through any nesting. +For instance, `5*E<0,1>{}` prints as `5@1@0` +and means `((0,5,0,...),0,...)`. + +Basis elements can also be added together, +as long as their hierarchical structures are compatible. +For example, `3*E<0>{} + 4*E<1>{}` results in `(3,4,0,...)`. +Intuitively, "compatible" means that +the nested structure of the two basis elements +matches well enough to add the two elements together. + +### Linear combinations of strides Layouts work by taking the inner product -of their input coordinate with the strides. -For "ordinary" integer strides, e.g., `(1, 100)`, -the inner product of the input coordinate `(i, j)` -and the strides is `i + 100j`. -That gives the formula for the offset. -For strides built of basis functions, for example, -if the strides are `(_1@0, _1@1)`, -then the inner product of the input coordinate `(i, j)` -with the strides is `i@0 + j@1`. -The `i` here is a coefficient of the basis function `@0`, -and `j` is a coefficient of the basis function `@1`. -The result is a vector sum. We _interpret_ this result as -"the zeroth coefficient is i, and the first coefficient is j." -That translates into the (TMA) coordinate `(i, j)`. +of the natural coordinate with their strides. +For strides made of integer elements, e.g., `(1,100)`, +the inner product of the input coordinate `(i,j)` +and the stride is `i + 100j`. +Offsetting an "ordinary" tensor's pointer and this index +gives the pointer to the tensor element at `(i,j)`. + +For strides of basis elements, we still compute the inner product of the natural coordinate with the strides. +For example, if the stride is `(1@0,1@1)`, +then the inner product of the input coordinate `(i,j)` +with the strides is `i@0 + j@1 = (i,j)`. +That translates into the (TMA) coordinate `(i,j)`. If we wanted to reverse the coordinates, -then we could use `(_1@1, _1@0)` as the strides. -Evaluating the layout would give `i@1 + j@0`, -that is, `(j, i)`. +then we could use `(1@1,1@0)` as the stride. +Evaluating the layout would give `i@1 + j@0 = (j,i)`. + +A linear combination of basis elements +can be interpreted as a possibly multidimensional and hierarchical coordinate. +For instance, `2*2@1@0 + 3*1@1 + 4*5@1 + 7*1@0@0` +means `((0,2,...),0,...) + (0,3,0,...) + (0,20,0,...) + ((7,...),...) = ((7,2,...),23,...)` +and can be interpreted as the coordinate `((7,2),23)`. + +Thus, linear combinations of these strides can be used to generate TMA coordinates. +These coordinates, in turn, can be used to offset TMA coordinate iterators. + +## Application to TMA Tensors + +Now we can build CuTe Tensors like the one seen in the introduction. + +```cpp +Tensor a = make_tensor(make_inttuple_iter(0,0), + make_shape ( 4, 5), + make_stride(E<0>{}, E<1>{})); +print_tensor(a); + +Tensor b = make_tensor(make_inttuple_iter(0,0), + make_shape ( 4, 5), + make_stride(E<1>{}, E<0>{})); +print_tensor(b); +``` +prints +``` +ArithTuple(0,0) o (4,5):(_1@0,_1@1): + (0,0) (0,1) (0,2) (0,3) (0,4) + (1,0) (1,1) (1,2) (1,3) (1,4) + (2,0) (2,1) (2,2) (2,3) (2,4) + (3,0) (3,1) (3,2) (3,3) (3,4) + +ArithTuple(0,0) o (4,5):(_1@1,_1@0): + (0,0) (1,0) (2,0) (3,0) (4,0) + (0,1) (1,1) (2,1) (3,1) (4,1) + (0,2) (1,2) (2,2) (3,2) (4,2) + (0,3) (1,3) (2,3) (3,3) (4,3) +``` + + diff --git a/media/docs/cutlass_3x_backwards_compatibility.md b/media/docs/cutlass_3x_backwards_compatibility.md index 354e70dd..fa1df1fb 100644 --- a/media/docs/cutlass_3x_backwards_compatibility.md +++ b/media/docs/cutlass_3x_backwards_compatibility.md @@ -5,7 +5,7 @@ Although CUTLASS 3.0 restructures the GEMM hierarchy and introduces new types for the threadblock layer and below, we intend the entire source code to be usable in user applications. We expect users to be able to `#include` any source file from CUTLASS 3.0, whether -they implement the 2.x or the 3.x API, without breaking user builds. This means that a single +they implement the 2.x or the 3.x API, without breaking user builds. This means that a single translation unit should be able to contain any valid kernel regardless of its API version. The sections below discuss how `device` and `kernel` layer type names are made compatible across the two API versions, and what the users can expect out of the `threadblock` layer API going forward. @@ -126,7 +126,7 @@ a 2.x mainloop with a 3.0 collective epilogue. CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`. Each kernel layer schedule is specialized for a GEMM scheduling algorithm and GPU architecture. -Specializations of `kernel::GemmUniversal` for 3.0 APIs live in +Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `gemm_*.hpp` files in the directory [include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). The specialization to which to dispatch is decided through the dispatch policy's `Schedule` type. @@ -155,7 +155,7 @@ All CUTLASS 3 `kernel::GemmUniversal` specializations expose the following (stat static bool can_implement(Arguments const& args); -// Returns a dim3 representing the threadblock shape. +// Returns a dim3 representing the threadblock shape. static dim3 get_block_shape(); @@ -172,7 +172,7 @@ the 3.x API or 2.x API: // include/cutlass/gemm/gemm.h namespace cutlass:gemm::detail { - + // The following metafunction is used to detect whether a // `kernel::Gemm` or `kernel::GemmUniversal` implements the CUTLASS 3.x API, // by checking whether the problem shape type is aliased within. @@ -193,7 +193,7 @@ from that of CUTLASS 2.x. With that also comes the introduction of the of the 2.x `cutlass::gemm::threadblock` layer. Going forward, CUTLASS 3.x will discontinue new developments in the following namespaces. -* `cutlass::*::threadblock::*` +* `cutlass::*::threadblock::*` * `cutlass::*::warp::*` * `cutlass::gemm::thread::*` * `cutlass::arch::*` (except `barrier.h`) @@ -274,7 +274,7 @@ that live in the header file [`cutlass/layout/matrix.h`](/include/cutlass/layout/matrix.h). The interpretation of these layouts in GEMM depends on whether they are applied -to the input matrix A or B. For the matrix A, "column major" means +to the input matrix A or B. For the matrix A, "column major" means that mode corresponding to M extent has stride 1, and "row major" means that mode corresponding to K extent has stride 1. This is the usual computer science definition @@ -332,7 +332,7 @@ and K mode as the 1st mode of the stride. ### Conversions between 2.x tags and 3.0 types Starting with CUTLASS 3.0, all layouts are described using -`cute::Shape` and `cute::Stride` which compose into a `cute::Layout`. +`cute::Shape` and `cute::Stride` which compose into a `cute::Layout`. In CUTLASS 2.x, various layout tags such as `cutlass::layout::RowMajor` are used to specialize template implementations. These tag types only encode information about the tensor strides, as 2.x layouts did not incorporate any concept of tensor shape in the layout tags themselves. @@ -415,18 +415,18 @@ Here is an excerpt. static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; // Warp shape is not a primary API type in 3.x, - // but we can best approximate it by inspecting the TiledMma::TiledShape_MNK. + // but we can best approximate it by inspecting the TiledMma // For this, we make the assumption that we always have 4 warps along M, // and the rest along N, with none along K. We also always round up // the warp count to 4 if the tiled mma is smaller than 128 threads. - static constexpr int WarpsInMma = std::max(4, cute::size(typename GemmKernel::TiledMma{}) / 32); + static constexpr int WarpsInMma = std::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); static constexpr int WarpsInMmaM = 4; static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); using WarpCount = cutlass::gemm::GemmShape; using WarpShape = cutlass::gemm::GemmShape< - cute::size<0>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaM, - cute::size<1>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{}) / WarpsInMmaN, - cute::size<2>(typename CollectiveMainloop::TiledMma::TiledShape_MNK{})>; + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; // Inspect TiledCopy for A and B to compute the alignment size static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< @@ -435,7 +435,7 @@ Here is an excerpt. typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); ``` -CUTLASS's library and profiler use these reflective interfaces to +CUTLASS's library and profiler use these reflective interfaces to obtain the kernel's configuration parameters. Users can use these to approximate the CUTLASS 2.x types for 3.0 API kernels. However, the reflective interfaces cannot always match the types exactly, as the mappings are not always bijective. diff --git a/pyproject.toml b/pyproject.toml index 9b9224b2..3a91d394 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.3.0.0" +version = "3.4.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/README.md b/python/README.md index e84b7963..e453932a 100644 --- a/python/README.md +++ b/python/README.md @@ -14,7 +14,7 @@ import cutlass import numpy as np plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor) -A, B, C, D = [np.ones((4096, 4096), dtype=np.float16) for i in range(4)] +A, B, C, D = [np.ones((1024, 1024), dtype=np.float16) for i in range(4)] plan.run(A, B, C, D) ``` @@ -67,7 +67,7 @@ The CUTLASS Python interface currently supports the following operations: We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch): ```bash -docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 +docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 -p 8888:8888 ``` The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8 and 3.9. @@ -99,6 +99,24 @@ If you would like to be able to make changes to CUTLASS Python interface and hav pip install -e . ``` +To test that your installation was successful, you can run: +```python +import cutlass +import numpy as np + +plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor) +A, B, C, D = [np.ones((128, 128), dtype=np.float16) for i in range(4)] +plan.run(A, B, C, D) +``` + +### Deep learning framework CUDA extensions +The CUTLASS Python interface provides utilities for exporting a CUTLASS kernel to a deep learning framework CUDA extensions. Currently, PyTorch CUDA extensions can be exported, but a similar pattern could be applied for other frameworks as well. An example of this is provided [here](/examples/python/02_pytorch_extension_grouped_gemm.ipynb). + +Currently, the following operations can be exported to a PyTorch CUDA extension: +* GEMM +* Grouped GEMM +* Conv2d + ### Examples Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python). diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index 0fd755ca..c919d6b7 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -75,6 +75,7 @@ from cutlass_library import ( DataType, EpilogueScheduleType, KernelScheduleType, + MathOperation, LayoutType, OpcodeClass, TileDescription, @@ -120,7 +121,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '3.3.0' +this.__version__ = '3.4.0' from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 17954a93..430beb83 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -34,7 +34,8 @@ import ctypes from cutlass_library import ( DataType, - KernelScheduleType + KernelScheduleType, + TileSchedulerType ) from cutlass.backend.library import DataTypeSizeBytes @@ -99,6 +100,7 @@ class StrideBatched_(ctypes.Structure): ] + class GenericMainloopArguments3x_(ctypes.Structure): """ Structure representing the superset of possible mainloop arguments. @@ -115,6 +117,45 @@ class GenericMainloopArguments3x_(ctypes.Structure): ] +class _PersistentTileSchedulerArguments(ctypes.Structure): + _fields_ = [ + ("max_swizzle_size", ctypes.c_int), + ("raster_order_option", ctypes.c_int), + ] + + +class _PersistentTileSchedulerStreamKArguments(ctypes.Structure): + _fields_ = [ + ("splits", ctypes.c_int), + ("max_swizzle_size", ctypes.c_int), + ("raster_order_option", ctypes.c_int), + ("reduction_mode", ctypes.c_int), + ("decomposition_mode", ctypes.c_int), + ] + + +def get_tile_scheduler_arguments_3x( + tile_scheduler: TileSchedulerType, + splits: int = 1): + max_swizzle_size = 1 + raster_order_option = 0 # Heuristic + if tile_scheduler == TileSchedulerType.Persistent: + return _PersistentTileSchedulerArguments( + max_swizzle_size, + raster_order_option, + ) + elif tile_scheduler == TileSchedulerType.StreamK: + reduction_mode = 0 # Deterministic + decomposition_mode = 0 # Heuristic + return _PersistentTileSchedulerStreamKArguments( + splits, + max_swizzle_size, + raster_order_option, + reduction_mode, + decomposition_mode, + ) + + def get_mainloop_arguments_3x( kernel_schedule: KernelScheduleType, element_A, @@ -172,7 +213,7 @@ def get_mainloop_arguments_3x( return _MainloopArgumentsTma -def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor): +def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args): _EpilogueOutputOpParams = epilogue_functor.epilogue_type if hasattr(epilogue_functor, "visitor"): class _EpilogueArguments(ctypes.Structure): @@ -187,7 +228,6 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor): self.arg_C = epilogue_functor.arg_c_type(ptr_c) self.arg_D = epilogue_functor.arg_d_type(ptr_d) else: - class _EpilogueArguments(ctypes.Structure): _fields_ = [ ("epilogue", _EpilogueOutputOpParams), @@ -210,7 +250,7 @@ def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor): ("mainloop", mainloop_arguments), ("epilogue", _EpilogueArguments), ("hw_info", _HardwareInfo), - ("splits", ctypes.c_int) + ("scheduler", type(scheduler_args)), ] return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo diff --git a/python/cutlass/backend/evt/backend/sm90_nodes.py b/python/cutlass/backend/evt/backend/sm90_nodes.py index 4304c2cc..43b75328 100644 --- a/python/cutlass/backend/evt/backend/sm90_nodes.py +++ b/python/cutlass/backend/evt/backend/sm90_nodes.py @@ -87,7 +87,7 @@ class Sm90LoadSrcImpl(LoadSrcImpl): self._type_decl = f""" using ElementC = {DataTypeTag[self.element]}; using StrideC = {self.stride_mnl}; -using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch; +using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>; """ return self._type_decl diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py index a49c1541..b555deb7 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass/backend/evt/epilogue.py @@ -44,6 +44,7 @@ from cutlass.backend.epilogue import EpilogueFunctorBase import cutlass.backend.evt.backend from cutlass.backend.frontend import TensorFrontend from cutlass.utils.datatypes import is_numpy_tensor +from cutlass.backend.evt.passes.util import cc_map class EpilogueFunctorVisitor(EpilogueFunctorBase): @@ -56,7 +57,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase): """ def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: # Type of Emitter based on CC - self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc}Emitter") + self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") # Visitor Types self.visitor = visitor diff --git a/python/cutlass/backend/evt/passes/pass_get_impl.py b/python/cutlass/backend/evt/passes/pass_get_impl.py index 90c74607..0e56eb7a 100644 --- a/python/cutlass/backend/evt/passes/pass_get_impl.py +++ b/python/cutlass/backend/evt/passes/pass_get_impl.py @@ -45,6 +45,7 @@ from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD from cutlass.backend.evt.passes.pass_manager import EVTPassBase from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass.backend.evt.passes.util import cc_map class PassGetImpl(EVTPassBase): @@ -82,8 +83,8 @@ class PassGetImpl(EVTPassBase): self.no_op_elimination() # Lower to cc-specific impl for node_meta in self.dag_ir.nodes_meta: - node_impl_ccs = getattr(evt_backend, f"sm{self.cc}_nodes") + node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes") node_meta.underlying_impl = getattr( node_impl_ccs, - f"Sm{self.cc}" + node_meta.underlying_impl.__class__.__name__ + f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__ )(node_meta) diff --git a/python/cutlass/backend/evt/passes/pass_manager.py b/python/cutlass/backend/evt/passes/pass_manager.py index 4fa31a8b..c2c50e60 100644 --- a/python/cutlass/backend/evt/passes/pass_manager.py +++ b/python/cutlass/backend/evt/passes/pass_manager.py @@ -39,6 +39,7 @@ from typing import Any import networkx as nx from cutlass.backend.evt.ir import DAGIR +from cutlass.backend.evt.passes.util import cc_map class EVTPassBase: @@ -102,7 +103,7 @@ class EVTPassBase: // sm80 specific method return """ - func_name = f"sm{self.cc}_{func.__name__}" + func_name = f"sm{cc_map[self.cc]}_{func.__name__}" if hasattr(self, func_name): return getattr(self, func_name) else: diff --git a/python/cutlass/backend/evt/passes/util.py b/python/cutlass/backend/evt/passes/util.py new file mode 100644 index 00000000..5e607fd1 --- /dev/null +++ b/python/cutlass/backend/evt/passes/util.py @@ -0,0 +1,43 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for passes +""" + +# Map from the CC of the kernel to the EVT implementation that the CC targets +cc_map = { + 80: 80, + 86: 80, + 89: 80, + 90: 90, +} diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index c5c756db..76543d76 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -82,7 +82,8 @@ from cutlass.backend.c_types import ( get_gemm_arguments_3x, get_gemm_arguments_streamk, get_gemm_grouped_arguments, - get_mainloop_arguments_3x + get_mainloop_arguments_3x, + get_tile_scheduler_arguments_3x, ) from cutlass.backend.library import ( ApiVersion, @@ -554,6 +555,7 @@ class GemmArguments3x(GemmArguments2x): mainloop, epilogue, hw_info, + self.operation.rt_module.scheduler_args ) return self.arguments @@ -1163,7 +1165,9 @@ extern "C" { operation.A.alignment, operation.B.alignment ) - self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor) + self.scheduler_args = get_tile_scheduler_arguments_3x(operation.tile_description.tile_scheduler) + self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x( + self.mainloop_args, operation.epilogue_functor, self.scheduler_args) def get_device_workspace_size(self, arguments: GemmArguments3x): return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index ef1a8fce..1267498b 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -34,6 +34,7 @@ Classes containing valid operations for a given compute capability and data types. """ +from itertools import combinations_with_replacement import logging from cuda import __version__ @@ -60,6 +61,7 @@ class KernelsForDataType: def __init__(self, datatype_comb: tuple, layout_comb: tuple): self.datatype_comb = datatype_comb self.layout_comb = layout_comb + self.math_operations = set() # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment # constraint for the data type combination @@ -73,6 +75,7 @@ class KernelsForDataType: if alignment_key not in self.kernels_by_alignment: self.kernels_by_alignment[alignment_key] = [] self.kernels_by_alignment[alignment_key].append(operation) + self.math_operations.add(operation.tile_description.math_instruction.math_operation) def alignments(self, operand: str): """ @@ -100,11 +103,14 @@ class KernelsForDataType: ops.extend(alignment_ops) return ops - def default_operation(self): + def default_operation(self, math_operation: cutlass.MathOperation): key = sorted(list(self.kernels_by_alignment.keys()))[0] - return self.kernels_by_alignment[key][0] + kernels = self.kernels_by_alignment[key] + if math_operation is not None: + kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation] + return kernels[0] - def operations(self, alignment_A: int, alignment_B: int, alignment_C: int): + def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation): """ Returns operations satisfying the alignment constraints @@ -114,6 +120,8 @@ class KernelsForDataType: :type alignment_B: int :param alignment_C: alignment constraint of operations to return :type alignment_C: int + :param math_operation: math operation to consider + :type math_operation: cutlass.MathOperation :return: list of operations :rtype: list @@ -126,13 +134,26 @@ class KernelsForDataType: min_alignment = min(alignment_A, alignment_B, alignment_C) key = f"{min_alignment} {min_alignment} {min_alignment}" if key not in self.kernels_by_alignment: - raise Exception( - f"No operations of alignment {og_key} found for data type and layout " - f"combination {self.datatype_comb} {self.layout_comb}. Tried to fall back " - f"to alignment {key}, but that was also not compatible. Compatible alignments " - f"are {self.kernels_by_alignment.keys()}" - ) - return self.kernels_by_alignment[key] + # Finally, go through all available alignment combinations and find + # one for which all values are less than those passed in. + key = None + alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) + for align_A, align_B, align_C in alignments: + if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C: + key = f"{align_A} {align_B} {align_C}" + break + + if key is None: + raise Exception( + f"No operations of alignment {og_key} found for data type and layout " + f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments " + f"are {self.kernels_by_alignment.keys()}" + ) + + ops = self.kernels_by_alignment[key] + if math_operation is not None: + ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation] + return ops def _operand_idx(self, key: str) -> int: operand_list = ["A", "B", "C"] @@ -187,6 +208,18 @@ class KernelsForDataType: for alignment in self.kernels_by_alignment.keys(): self.kernels_by_alignment[alignment].sort(key=key, reverse=True) + def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool: + """ + Returns whether `math_operation` is supported by at least one operation. + + :param math_operation: math operation to consider + :type math_operation: cutlass.MathOperation + + :return: whether math_operation is supported by at least one operation + :rtype: bool + """ + return math_operation is None or math_operation in self.math_operations + class ArchOptions: """ @@ -213,7 +246,8 @@ class ArchOptions: allowed_math_operations: list = [ cutlass_library.MathOperation.multiply_add, cutlass_library.MathOperation.multiply_add_saturate, - cutlass_library.MathOperation.multiply_add_mixed_input_upcast + cutlass_library.MathOperation.multiply_add_mixed_input_upcast, + cutlass_library.MathOperation.multiply_add_fast_f32 ] ): self.cc = kernel_cc @@ -270,8 +304,6 @@ class ArchOptions: if mi.math_operation not in self.allowed_math_operations: continue - datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) - # Prune operations that don't fit in shared memory td = td_from_profiler_op(op) if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]: @@ -323,6 +355,15 @@ class ArchOptions: (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), ] + # Add FP8 A/B/C + fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2] + for type_comb in combinations_with_replacement(fp8_types, 3): + types.append(type_comb) + + # Add FP8 A/B with FP32 C + for type_comb in combinations_with_replacement(fp8_types, 2): + types.append(type_comb + (cutlass.DataType.f32,)) + layouts = [ (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor), (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor), @@ -395,7 +436,7 @@ class ArchOptions: self.operations_by_opclass[oc][comb].sort() def opclass_supports_combination( - self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple + self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation ) -> bool: """ Returns whether the provided operation class supports the provided data type and layout combination @@ -406,6 +447,8 @@ class ArchOptions: :type datatype_comb: tuple[cutlass_library.DataType] :param layout_comb: tuple of data types for (layout_A, layout_B) :type layout_comb: tuple[cutlass_library.LayoutType] + :param math_operation: math operation to consider or None if any can be considered + :type math_operation: cutlass.MathOperation :return: set of operation classes that support the provided data type and layout combination :rtype: set @@ -413,7 +456,14 @@ class ArchOptions: if op_class not in self.operations_by_opclass: raise Exception(f"Unexpected or unsupported operation class {op_class}") - return (datatype_comb, layout_comb) in self.operations_by_opclass[op_class] + if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)): + if math_operation is not None: + return operations.supports_math_operation(math_operation) + else: + return True + + return False + def supporting_opclasses( self, @@ -422,6 +472,7 @@ class ArchOptions: element_accumulator: cutlass_library.DataType, layout_a: cutlass_library.LayoutType, layout_b: cutlass_library.LayoutType, + math_operation: cutlass_library.MathOperation, ) -> set: """ Returns a set of operation classes that support the provided data type combination @@ -436,6 +487,8 @@ class ArchOptions: :type layout_a: cutlass_library.LayoutType :param layout_b: layout of operand B :type layout_b: cutlass_library.LayoutType + :param math_operation: math operation to consider + :type math_operation: cutlass.MathOperation :return: set of operation classes that support the provided data type combination :rtype: set @@ -445,7 +498,7 @@ class ArchOptions: layout_comb = (layout_a, layout_b) for op_class in self.operations_by_opclass.keys(): - if self.opclass_supports_combination(op_class, datatype_comb, layout_comb): + if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): supporting_op_classes.add(op_class) return supporting_op_classes @@ -457,6 +510,7 @@ class ArchOptions: element_accumulator: cutlass_library.DataType, layout_a: cutlass_library.LayoutType, layout_b: cutlass_library.LayoutType, + math_operation: cutlass_library.MathOperation, ) -> KernelsForDataType: """ Returns whether the provided operation class supports the provided data type combination @@ -473,13 +527,15 @@ class ArchOptions: :type layout_a: cutlass_library.LayoutType :param layout_b: layout of operand B :type layout_b: cutlass_library.LayoutType + :param math_operation: math operation to consider + :type math_operation: cutlass.MathOperation :return: container of kernels by alignment supported by the provided combination of parameters :rtype: KernelsForDataType """ datatype_comb = (element_a, element_b, element_accumulator) layout_comb = (layout_a, layout_b) - if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb): + if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation): raise Exception( f"Data type layout combination {datatype_comb}, {layout_comb} " f"is not supported by opcode class {op_class} on CC {self.cc}." diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py index d7cd90ad..3b8545fa 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass/op/conv.py @@ -293,7 +293,7 @@ class Conv2d(OperationBase): self.possible_op_classes = self.options.supporting_opclasses( self._element_a, self._element_b, self._element_accumulator, - self._layout_a, self._layout_b + self._layout_a, self._layout_b, self._math_operation ) if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: @@ -301,8 +301,13 @@ class Conv2d(OperationBase): elif cutlass.OpcodeClass.Simt in self.possible_op_classes: self.opclass = cutlass.OpcodeClass.Simt else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + raise Exception(f'No kernel configuration found for supported data type and layout ' - f'combination {datatype_comb}x{layout_comb}') + f'combination {datatype_comb}x{layout_comb}{math_op_str}') if reset_epilogue: self._reset_epilogue_functor_activation(epilogue.identity) @@ -345,7 +350,7 @@ class Conv2d(OperationBase): return if isinstance(td, dict): if self._tile_description is None: - op = self.possible_operations.default_operation() + op = self.possible_operations.default_operation(self._math_operation) self._tile_description = datatypes.td_from_profiler_op(op) if "cluster_shape" in td.keys(): if td["cluster_shape"] != [1, 1, 1]: @@ -397,6 +402,11 @@ class Conv2d(OperationBase): description_str = [] for op in self.possible_operations.all_operations: td = datatypes.td_from_profiler_op(op) + + if self._math_operation is not None: + if td.math_instruction.math_operation != self._math_operation: + continue + if str(td) not in description_str: description_str.append(str(td)) descriptions.append(td) @@ -569,7 +579,7 @@ class Conv2d(OperationBase): if self.tile_description is not None: tile_description = self.tile_description else: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] tile_description = datatypes.td_from_profiler_op(op) else: valid, err_str = self._valid_tile_description(tile_description) diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index 718696f1..72afcba5 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -202,14 +202,14 @@ class Gemm(OperationBase): :type element_C: cutlass.DataType :param element_D: data type to be used for operand D :type element_D: cutlass.DataType - :type layout_A: layout of operand A - :param layout_A: cutlass.LayoutType - :type layout_B: layout of operand B - :param layout_B: cutlass.LayoutType - :type layout_C: layout of operand C - :param layout_C: cutlass.LayoutType - :type layout_D: layout of operand D - :param layout_D: cutlass.LayoutType + :param layout_A: layout of operand A + :type layout_A: cutlass.LayoutType + :param layout_B: layout of operand B + :type layout_B: cutlass.LayoutType + :param layout_C: layout of operand C + :type layout_C: cutlass.LayoutType + :param layout_D: layout of operand D + :type layout_D: cutlass.LayoutType """ def __init__( @@ -281,17 +281,23 @@ class Gemm(OperationBase): # Set the default op class datatype_comb = (self._element_a, self._element_b, self._element_accumulator) layout_comb = (self._layout_a, self._layout_b) + self.possible_op_classes = self.options.supporting_opclasses( self._element_a, self._element_b, self._element_accumulator, - self._layout_a, self._layout_b) + self._layout_a, self._layout_b, self._math_operation) if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: self.opclass = cutlass.OpcodeClass.TensorOp elif cutlass.OpcodeClass.Simt in self.possible_op_classes: self.opclass = cutlass.OpcodeClass.Simt else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + raise Exception(f'No kernel configuration found for supported data type and layout ' - f'combination {datatype_comb}x{layout_comb}') + f'combination {datatype_comb}x{layout_comb}{math_op_str}') if reset_epilogue: self._reset_epilogue_functor_activation(cutlass.epilogue.identity) @@ -349,7 +355,7 @@ class Gemm(OperationBase): return if isinstance(td, dict): if self._tile_description is None: - op = self.possible_operations.default_operation() + op = self.possible_operations.default_operation(self._math_operation) self._tile_description = datatypes.td_from_profiler_op(op) td = self._tile_description.clone_and_update(td) @@ -394,7 +400,10 @@ class Gemm(OperationBase): :returns: list of valid tile descriptions for the operations :rtype: list """ - return [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + if self._math_operation is not None: + tds = [td for td in tds if td.tile_description.math_instruction == self._math_operation] + return tds def construct( self, tile_description: TileDescription = None, @@ -423,18 +432,19 @@ class Gemm(OperationBase): tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) - alignment_pref_C = max(self.possible_operations.alignments("C")) - if self._element_c != DataType.void: - alignment_pref_C = min(128 // DataTypeSize[self._element_c], alignment_pref_C) - - alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C) - tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) - self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + if alignment_C is None: + alignment_C = max(self.possible_operations.alignments("C")) + if self._element_c != DataType.void: + alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C) if tile_description is None: if self._tile_description is None: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] tile_description = datatypes.td_from_profiler_op(op) + + # The selected op may have lower alignment than that determined above, so we must + # reset alignment here. + alignment_C = op.C.alignment else: tile_description = self._tile_description else: @@ -443,6 +453,9 @@ class Gemm(OperationBase): raise Exception(f"Invalid tile description. {err_str}") self._tile_description = tile_description + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + operation = GemmOperationUniversal( arch=self.current_cc, tile_description=tile_description, @@ -599,9 +612,13 @@ class Gemm(OperationBase): """ dtype, layout = datatypes.get_datatype_and_layout(tensor) if dtype != ref_type or layout != ref_layout: - raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' - f'does not match the expected type and ' - f'layout of ({ref_type}, {ref_layout}).') + try: + # Attempt to transpose the tensor to fit the desired layout + tensor = tensor.transpose(-1, -2) + except: + raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' + f'does not match the expected type and ' + f'layout of ({ref_type}, {ref_layout}) and transpose failed.') def run(self, A=None, B=None, C=None, D=None, alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments: diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index d20ac507..f88dea17 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -174,7 +174,7 @@ class GroupedGemm(Gemm): tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) if tile_description is None: - op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] tile_description = datatypes.td_from_profiler_op(op) else: valid, err_str = self._valid_tile_description(tile_description) diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py index d0630d67..6ed35bf9 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass/op/op.py @@ -36,7 +36,13 @@ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv from bisect import bisect_left -from cutlass_library import DataType, DataTypeSize, OperationKind, SharedMemPerCC +from cutlass_library import ( + DataType, + DataTypeSize, + MathOperation, + OperationKind, + SharedMemPerCC +) import cutlass from cutlass import get_option_registry @@ -67,6 +73,7 @@ class OperationBase: self.specified_kernel_cc = kernel_cc is not None self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) self.tile_description = None + self._math_operation = None self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) @@ -197,14 +204,10 @@ class OperationBase: self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) return tensor - # - # Opcode Related - # - @property def opclass(self) -> cutlass.OpcodeClass: """ - Returns the opcode class currently in use by the GEMM + Returns the opcode class currently in use :return: opcode class currently in use :rtype: cutlass.OpcodeClass @@ -226,15 +229,41 @@ class OperationBase: # Changing the op class also changes the possible operations available. Reset these. self.possible_operations = self.options.operations( self.op_class, self._element_a, self._element_b, - self._element_accumulator, self._layout_a, self._layout_b) + self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) # Changing the op class changes the elements per access in the epilogue. Reset this. if self.epilogue_functor is not None: self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) - # - # Epilogue - # + @property + def math_operation(self) -> cutlass.MathOperation: + """ + Returns the math operation currently in use + + :return: math operation currently in use + :rtype: cutlass.MathOperation + """ + return self._math_operation + + @math_operation.setter + def math_operation(self, mo: cutlass.MathOperation): + if isinstance(mo, str): + mo = datatypes.getattr_enum(cutlass.MathOperation, mo) + + if not self.specified_kernel_cc: + if self.current_cc == 90: + # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif self.current_cc == 90: + raise Exception("CUTLASS 3.0 kernels do not use different math operations. " + "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" + "parameter when constructing the plan.") + + self._math_operation = mo + self._reset_operations() def _elements_per_access(self): if self.op_class == cutlass.OpcodeClass.Simt: @@ -262,7 +291,7 @@ class OperationBase: raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif (self.cc == 90 and self.current_cc != 90 and activation == identity): + elif (self.cc == 90 and self.current_cc != 90 and activation == identity and self._math_operation is None): # SM80 fallback kernels are currently used. Since an identity activation is requested, # we can switch back to using SM90 kernels. self._reset_options(90) @@ -272,7 +301,7 @@ class OperationBase: raise Exception("Epilogues with elementwise fusion are not currently supported " "in the Python interface for 3.x kernels. To use 2.x kernels " "with fused elementwise epilogues, do not set the `kernel_cc` " - "parameter when constructing the Gemm object.") + "parameter when constructing the plan.") return get_activation_epilogue( activation, @@ -364,6 +393,7 @@ class OperationBase: # The shared memory is only a concern for sm90 epilogue # In sm80, the epilogue and mainloop share the shared memory return + datatype_comb = self.possible_operations.datatype_comb layout_comb = self.possible_operations.layout_comb new_possible_operations = KernelsForDataType(datatype_comb, layout_comb) diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index d26ada29..a4f90d36 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -176,6 +176,17 @@ def is_torch_available(): cutlass.DataType.s32: torch.int32, cutlass.DataType.u8: torch.uint8, } + + def possibly_add_type(torch_type_name, cutlass_type): + # Only try adding the type if the version of torch being used supports it + if hasattr(torch, torch_type_name): + torch_type = getattr(torch, torch_type_name) + _torch_to_library_dict[torch_type] = cutlass_type + _library_to_torch_dict[cutlass_type] = torch_type + + possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3) + possibly_add_type("float8_e5m2", cutlass.DataType.e5m2) + except ImportError: torch_available = False _torch_to_library_dict = {} diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 11691f42..c8d1de0b 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -61,7 +61,8 @@ class GemmOperation: def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, - tile_scheduler = TileSchedulerType.Default, extra_args = None): + tile_scheduler = TileSchedulerType.Default + ): kinds_3x = { GemmKind.Universal3x, @@ -88,6 +89,10 @@ class GemmOperation: self.epilogue_schedule = epilogue_schedule self.element_epilogue = element_epilogue self.epilogue_functor = epilogue_functor + + if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination: + self.epilogue_functor = EpilogueFunctor3x.LinearCombination + self.swizzling_functor = swizzling_functor self.tile_scheduler = tile_scheduler @@ -709,9 +714,9 @@ class EmitGemmUniversal3xInstance: ] self.builtin_epilogue_functor_template = """ ${epilogue_functor}< + ${element_d}, + ${element_epilogue}, ${element_c}, - ${epilogue_vector_length}, - ${element_accumulator}, ${element_epilogue} > """ @@ -726,7 +731,8 @@ using ${operation_name}_epilogue = ${element_accumulator}, ${element_epilogue}, ${element_c}, ${layout_c}, ${align_c}, ${element_d}, ${layout_d}, ${align_d}, - ${epilogue_schedule} + ${epilogue_schedule}, + ${epilogue_functor} >::CollectiveOp; using ${operation_name}_mainloop = @@ -757,9 +763,11 @@ struct ${operation_name} : def instance_template(self): return """ ${compile_guard_start} - using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; - manifest.append( - new ${gemm_kind}("${operation_name}")); + { + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + } ${compile_guard_end} """ @@ -788,9 +796,8 @@ ${compile_guard_end} # Support built-in epilogue functors or user-defined functions if isinstance(operation.epilogue_functor, enum.Enum): values = { - 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], } epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) else: @@ -799,6 +806,9 @@ ${compile_guard_end} element_a = DataTypeTag[operation.A.element] element_b = DataTypeTag[operation.B.element] epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + element_a = DataTypeTag[operation.A.element] + element_b = DataTypeTag[operation.B.element] + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 1f07e76b..1cc61b42 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -192,14 +192,14 @@ def CreateGemmUniversal3xOperator( C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) - extra_args = {} + gemm_op_extra_args = {} gemm_kind = GemmKind.Universal3x element_compute = data_type.get("epi_type", data_type["acc_type"]) operation = GemmOperation( gemm_kind, tile_description.minimum_compute_capability, tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule, tile_scheduler, extra_args) + kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args) manifest.append(operation) operations.append(operation) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index c0c425c2..d288d43e 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -466,6 +466,13 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', } +class EpilogueFunctor3x(enum.Enum): + LinearCombination = enum_auto() +# +EpilogueFunctor3xTag = { + EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', +} + class TileSchedulerType(enum.Enum): Default = enum_auto() Persistent = enum_auto() diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index 6efc7659..4a0fa3e2 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -429,7 +429,7 @@ class Manifest: self.kernel_filter_list = [] else: self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) - _LOGGER.info("Using {filter_count} kernel filters from {filter_file}".format( + _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( filter_count = len(self.kernel_filter_list), filter_file = args.kernel_filter_file)) diff --git a/python/pycute/layout.py b/python/pycute/layout.py index c2ad2d11..de823fe1 100644 --- a/python/pycute/layout.py +++ b/python/pycute/layout.py @@ -101,7 +101,7 @@ class Layout(LayoutBase): # cosize(layout) Size of the codomain def cosize(self): - return tuple_max(tuple((1, elem_scale(self.shape, self.stride)))) + return self(self.size() - 1) + 1 # print and str def __str__(self): diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index cf57b223..7e78a218 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass', - version='3.3.0', + version='3.4.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index 17905f40..115e6c0a 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='3.3.0', + version='3.4.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 316dbc88..bf06967d 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='3.3.0', + version='3.4.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/setup.cfg b/setup.cfg index c996eed4..a4216c2b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = nvidia-cutlass -version = 3.3.0.0 +version = 3.4.0.0 [options] packages = diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index e79a2822..be6f1af3 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase cutlass.set_log_level(logging.WARNING) -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class TestEVTComputeSM90(EVTTestCaseBase): +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTCompute(EVTTestCaseBase): def test_arith(self): """ diff --git a/test/python/cutlass/evt/evt_layout_sm80_90.py b/test/python/cutlass/evt/evt_layout_sm80_90.py index 3cbc9530..dba40003 100644 --- a/test/python/cutlass/evt/evt_layout_sm80_90.py +++ b/test/python/cutlass/evt/evt_layout_sm80_90.py @@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase cutlass.set_log_level(logging.WARNING) -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class TestEVTLayoutSM90(EVTTestCaseBase): +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLayout(EVTTestCaseBase): def test_permute_1(self): """ @@ -74,7 +74,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase): result_keys = ["D", "F"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only") + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") def test_permute_2(self): """ Returning a tensor with shape [m, n] @@ -99,7 +99,7 @@ class TestEVTLayoutSM90(EVTTestCaseBase): result_keys = ["D", "F"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() == 80, "This unittest is for cc = Sm90 only") + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") def test_permute_3(self): """ Returning a tensor with shape [m, n] diff --git a/test/python/cutlass/evt/evt_load_sm80_90.py b/test/python/cutlass/evt/evt_load_sm80_90.py index 57a8ed0d..2758de93 100644 --- a/test/python/cutlass/evt/evt_load_sm80_90.py +++ b/test/python/cutlass/evt/evt_load_sm80_90.py @@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase cutlass.set_log_level(logging.WARNING) -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class TestEVTLoadSM90(EVTTestCaseBase): +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLoad(EVTTestCaseBase): def test_tensor_load(self): """ diff --git a/test/python/cutlass/evt/evt_mixed_sm80_90.py b/test/python/cutlass/evt/evt_mixed_sm80_90.py index 82fd75bb..343c3d26 100644 --- a/test/python/cutlass/evt/evt_mixed_sm80_90.py +++ b/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -47,8 +47,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase cutlass.set_log_level(logging.WARNING) -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class TestEVTMixedSM90(EVTTestCaseBase): +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTMixed(EVTTestCaseBase): def test_mixed_dag(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) @@ -84,7 +84,7 @@ class TestEVTMixedSM90(EVTTestCaseBase): result_keys = ["D", "F", "F_row_max", "E_col_max"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only") + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") def test_mixed_dag_float(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) @@ -114,7 +114,7 @@ class TestEVTMixedSM90(EVTTestCaseBase): result_keys = ["D", "F", "F_row_max", "E_col_max"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only") + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") def test_mixed_dag_stage2(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) @@ -144,7 +144,7 @@ class TestEVTMixedSM90(EVTTestCaseBase): result_keys = ["D", "F", "F_row_max", "E_col_max"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only") + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") def test_mixed_dag_partition_k(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) @@ -179,7 +179,7 @@ class TestEVTMixedSM90(EVTTestCaseBase): result_keys = ["D", "F", "F_row_max", "E_col_max"] launcher.verify((m, n, k), input_keys, result_keys, l) - @unittest.skipIf(device_cc() != 80, "This unittest is for cc = Sm80 only") + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") def test_mixed_dag_stream_k(self): def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): F = alpha * accum + (beta * C + aux) diff --git a/test/python/cutlass/evt/evt_store_sm80_90.py b/test/python/cutlass/evt/evt_store_sm80_90.py index 7046d4d0..859cc810 100644 --- a/test/python/cutlass/evt/evt_store_sm80_90.py +++ b/test/python/cutlass/evt/evt_store_sm80_90.py @@ -46,8 +46,8 @@ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase cutlass.set_log_level(logging.WARNING) -@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") -class TestEVTStoreSM90(EVTTestCaseBase): +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTStore(EVTTestCaseBase): def test_aux_store(self): """ diff --git a/test/python/cutlass/gemm/gemm_f16_sm80.py b/test/python/cutlass/gemm/gemm_f16_sm80.py index 02de6da4..4de34fa5 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm80.py +++ b/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 80 +dtype = cutlass.DataType.f16 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -56,13 +58,14 @@ class GemmF16Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ """ pass -add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1]) +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) diff --git a/test/python/cutlass/gemm/gemm_f16_sm90.py b/test/python/cutlass/gemm/gemm_f16_sm90.py index 0e8fe945..b521ce89 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm90.py +++ b/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -46,8 +46,10 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 90 +dtype = cutlass.DataType.f16 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -55,7 +57,7 @@ class GemmF16Sm90(unittest.TestCase): pass -add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=cutlass.DataType.f16, +add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, warp_count=None, compilation_modes=['nvcc']) add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) diff --git a/test/python/cutlass/gemm/gemm_f32_sm80.py b/test/python/cutlass/gemm/gemm_f32_sm80.py index 903965a3..12c3259a 100644 --- a/test/python/cutlass/gemm/gemm_f32_sm80.py +++ b/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 80 +dtype = cutlass.DataType.f32 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF32Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -56,6 +59,7 @@ class GemmF32Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF32Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -63,37 +67,37 @@ class GemmF32Sm80StreamK(unittest.TestCase): pass -add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=cc, cluster_shape=[1, 1, 1]) +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_f64_sm80.py b/test/python/cutlass/gemm/gemm_f64_sm80.py index 67049b94..fb4439f7 100644 --- a/test/python/cutlass/gemm/gemm_f64_sm80.py +++ b/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 80 +dtype = cutlass.DataType.f64 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -56,6 +59,7 @@ class GemmF64Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -63,36 +67,36 @@ class GemmF64Sm80StreamK(unittest.TestCase): pass -add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=cc, cluster_shape=[1, 1, 1]) +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_f64_sm90.py b/test/python/cutlass/gemm/gemm_f64_sm90.py index a3145a4b..b4b45827 100644 --- a/test/python/cutlass/gemm/gemm_f64_sm90.py +++ b/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 90 +dtype = cutlass.DataType.f64 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -56,8 +59,7 @@ class GemmF64Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], - element=cutlass.DataType.f64, element_output=cutlass.DataType.f64, - element_accumulator=cutlass.DataType.f64, compilation_modes=['nvcc']) + element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) diff --git a/test/python/cutlass/gemm/gemm_f8_sm90.py b/test/python/cutlass/gemm/gemm_f8_sm90.py new file mode 100644 index 00000000..45b12423 --- /dev/null +++ b/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -0,0 +1,112 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass +from cutlass.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass.DataType.e4m3 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E4M3Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) + +# Test with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, + element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, + element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.e4m3, + element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +# +# Add a test for E5M2 +# +dtype = cutlass.DataType.e5m2 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E5M2Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, + element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/cutlass/gemm/gemm_mixed_sm80.py b/test/python/cutlass/gemm/gemm_mixed_sm80.py index 152f8eb4..e88019e4 100644 --- a/test/python/cutlass/gemm/gemm_mixed_sm80.py +++ b/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 80 +dtype =cutlass.DataType.f16 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmMixedSm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -55,7 +58,7 @@ class GemmMixedSm80(unittest.TestCase): pass -add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1], +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32) diff --git a/test/python/cutlass/gemm/gemm_s8_sm80.py b/test/python/cutlass/gemm/gemm_s8_sm80.py index d9b929f9..20ebe831 100644 --- a/test/python/cutlass/gemm/gemm_s8_sm80.py +++ b/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 80 +dtype = cutlass.DataType.s8 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -56,6 +59,7 @@ class GemmS8Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -63,7 +67,7 @@ class GemmS8Sm80StreamK(unittest.TestCase): pass -add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc, cluster_shape=[1, 1, 1]) +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) diff --git a/test/python/cutlass/gemm/gemm_s8_sm90.py b/test/python/cutlass/gemm/gemm_s8_sm90.py index aafa9fd2..e1d7be64 100644 --- a/test/python/cutlass/gemm/gemm_s8_sm90.py +++ b/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -46,8 +46,11 @@ from utils import LayoutCombination, add_test_gemm cutlass.set_log_level(logging.WARNING) cc = 90 +dtype = cutlass.DataType.s8 + @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -55,7 +58,7 @@ class GemmS8Sm90(unittest.TestCase): pass -add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=cutlass.DataType.s8, compilation_modes=['nvcc']) +add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) diff --git a/test/python/cutlass/gemm/gemm_testbed.py b/test/python/cutlass/gemm/gemm_testbed.py index 2507cd75..1594a837 100644 --- a/test/python/cutlass/gemm/gemm_testbed.py +++ b/test/python/cutlass/gemm/gemm_testbed.py @@ -128,13 +128,22 @@ class GemmUniversalLauncher: def uniform_init(self, shape, dtype, layout): size = prod(shape) if dtype.is_floating_point: - data = torch.ceil(torch.empty(size=(size,), dtype=dtype, device="cuda").uniform_(self.rand_min - 0.5, self.rand_max - 0.5)) + # Initialize data in FP32 and call convert to the data type we desire. + # This is a workaround for the following error that occurs when attempting to + # call uniform_ on a tensor with torch.float8_e4m3fn data: + # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn' + data = torch.ceil( + torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_( + self.rand_min - 0.5, self.rand_max - 0.5) + ).to(dtype) else: # PyTorch does not currently support integer-typed matrix multiplications on GPU. # Fall back to CPU for integer type references. data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1) - if dtype == torch.float64 or dtype == torch.float32: + is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1) + + if dtype == torch.float64 or dtype == torch.float32 or is_fp8: data = data.to("cpu") data_ref = data.reshape(shape) @@ -145,6 +154,12 @@ class GemmUniversalLauncher: data_cutlass = data_ref.transpose(-1, -2).contiguous() data_cutlass = data_cutlass.to("cuda") + + # As of this writing, few operations in PyTorch are supported with FP8 data. + # Thus, we perform computation in FP32 for FP8 reference checks. + if is_fp8: + data_ref = data_ref.to(torch.float32) + return data_cutlass, data_ref def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index d102667f..5ca7cfe5 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -65,10 +65,10 @@ function(cutlass_test_unit_add_executable NAME) set(options WITHOUT_CUDA) set(oneValueArgs) - set(multiValueArgs) + set(multiValueArgs TEST_SETS_SUPPORTED EXTRA_INCLUDE_DIRS) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF) target_compile_definitions(${NAME} PUBLIC CUTLASS_TARGET_NAME="${NAME}") @@ -76,6 +76,7 @@ function(cutlass_test_unit_add_executable NAME) ${NAME} PRIVATE ${CUTLASS_UNIT_TEST_COMMON_DIR} + ${__EXTRA_INCLUDE_DIRS} ) if (__WITHOUT_CUDA) # Avoid CUDA dependencies for host-only unit tests that provide the @@ -110,12 +111,12 @@ function(cutlass_test_unit_add_executable NAME) cutlass_add_executable_tests( ${NAME_STEM} ${NAME} + TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED} TEST_COMMAND_OPTIONS CUTLASS_TEST_UNIT_TEST_COMMAND_OPTIONS ${RESULT_CACHE_FILE_ARGS} ) endfunction() - add_custom_target(cutlass_test_unit) add_custom_target(test_unit) diff --git a/test/unit/conv/device/cache_testbed_output.h b/test/unit/conv/cache_testbed_output.h similarity index 96% rename from test/unit/conv/device/cache_testbed_output.h rename to test/unit/conv/cache_testbed_output.h index 29be4346..a8f29b47 100644 --- a/test/unit/conv/device/cache_testbed_output.h +++ b/test/unit/conv/cache_testbed_output.h @@ -48,15 +48,15 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" +#include "thrust/universal_vector.h" + #ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS #define CUTLASS_TEST_ENABLE_CACHED_RESULTS false #endif ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace test { -namespace conv { -namespace device { +namespace test::conv::device { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -325,7 +325,6 @@ inline std::ostream &EncodeProblemSize( } ///////////////////////////////////////////////////////////////////////////////////////////////// - template inline std::string ElementTypeName() { return std::string(typeid(Element).name()); @@ -452,6 +451,12 @@ inline std::string TensorTypeName() { return ss.str(); } +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName(); + return ss.str(); +} ///////////////////////////////////////////////////////////////////////////////////////////////// /// Hash function on a byte array @@ -511,6 +516,16 @@ uint32_t TensorHash( return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); } +template +uint32_t TensorHash( + thrust::universal_vector& tensor, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -533,6 +548,23 @@ inline std::ostream &EncodeTypes( return out; } +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName(); + + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -790,8 +822,6 @@ inline CachedTestKey CreateCachedConv3dTestKey( ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // nammespace conv -} // namespace test +} // namespace test::conv::device ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 61f6ff73..9e6db4c5 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -55,7 +55,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" namespace test { namespace conv { diff --git a/test/unit/conv/device/conv2d_testbed_interleaved.h b/test/unit/conv/device/conv2d_testbed_interleaved.h index fe57ec85..8c1512f1 100644 --- a/test/unit/conv/device/conv2d_testbed_interleaved.h +++ b/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -56,7 +56,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" namespace test { namespace conv { diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index d1a1e666..0c73519e 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -59,7 +59,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" namespace test { namespace conv { diff --git a/test/unit/conv/device/conv2d_with_reduction_testbed.h b/test/unit/conv/device/conv2d_with_reduction_testbed.h index 7c573f06..0a52679f 100644 --- a/test/unit/conv/device/conv2d_with_reduction_testbed.h +++ b/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -56,7 +56,7 @@ #include "cutlass/core_io.h" #include "cutlass/util/tensor_view_io.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" namespace test { namespace conv { diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index 00c2eb1f..d4476ac0 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -55,7 +55,7 @@ #include "conv3d_problems.h" #include "cutlass/core_io.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" namespace test { namespace conv { diff --git a/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h index 1c2506cd..9a2662ee 100644 --- a/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h +++ b/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -36,7 +36,7 @@ #include #include "../../common/cutlass_unit_test.h" -#include "cache_testbed_output.h" +#include "../cache_testbed_output.h" #include "conv2d_problems.h" #include "cutlass/conv/device/direct_convolution.h" @@ -466,8 +466,8 @@ bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace device -} // namespace conv -} // namespace test +} // namespace device +} // namespace conv +} // namespace test ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index 6c97ed7e..5de353c7 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -64,3 +64,10 @@ add_executable( cpp11.cu ) +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + target_compile_options( + cutlass_test_unit_core_cpp11 + PRIVATE + $<$:-Xcompiler -Werror> + ) +endif() diff --git a/test/unit/core/cpp11.cu b/test/unit/core/cpp11.cu index 42fdbcad..dcc79a69 100644 --- a/test/unit/core/cpp11.cu +++ b/test/unit/core/cpp11.cu @@ -42,6 +42,7 @@ #include +#include #include #include #include @@ -51,12 +52,12 @@ #include #include #include +#include #include #include #include #include #include - #include #include diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu index ac3b5cf7..1057037a 100644 --- a/test/unit/core/fast_numeric_conversion.cu +++ b/test/unit/core/fast_numeric_conversion.cu @@ -147,6 +147,20 @@ TEST(FastNumericConversion, s32_to_f32) { test::core::kernel::run_test_integer_range_limited(); } +TEST(FastNumericConversion, s8_to_f32_array) { + int const kN = 256; + using Source = int8_t; + using Destination = float; + test::core::kernel::run_test_integer_range_all(); +} + +TEST(FastNumericConversion, u8_to_f32_array) { + int const kN = 256; + using Source = uint8_t; + using Destination = float; + test::core::kernel::run_test_integer_range_all(); +} + TEST(FastNumericConversion, s8_to_f16_array) { int const kN = 256; using Source = int8_t; diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 63b132f3..20eeeb09 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -60,8 +60,8 @@ __global__ void convert( ///////////////////////////////////////////////////////////////////////////////////////////////// -template -void run_test(const char dest_name[], const char source_name[]) { +template +void run_test(const char dest_name[], const char source_name[], const int range = 4, const int offset = 0) { const int kN = Count; dim3 grid(1, 1); @@ -73,7 +73,7 @@ void run_test(const char dest_name[], const char source_name[]) { auto destination_ref = destination.host_ref(); for (int i = 0; i < kN; ++i) { - source_ref.at({0, i}) = Source(i % Range); + source_ref.at({0, i}) = Source(i % range + offset); } source.sync_device(); @@ -509,4 +509,160 @@ TEST(NumericConversion, int_to_fe4m3_t_array_32) { test::core::kernel::run_test(dest_name, source_name); } +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct GetName { + static constexpr char name[] = "UNSUPPORTED"; +}; + +template <> +struct GetName { + static constexpr char name[] = "int4b_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "uint8_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "int8_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "float_e4m3_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "half_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "bfloat16_t"; +}; + +template <> +struct GetName { + static constexpr char name[] = "float"; +}; + +template +struct ResultSourcePair { + using Result = Result_; + using Source = Source_; +}; + +template +class VectorArrayConverterTest : public testing::Test { + public: + using Result = typename ResultSourcePair::Result; + using Source = typename ResultSourcePair::Source; + + template + static void emit_test() { + const int range = 1 << cutlass::sizeof_bits::value; + const int offset = cutlass::platform::numeric_limits::lowest(); + test::core::kernel::run_test(GetName::name, GetName::name, range, offset); + } +}; + +using VectorConvertTypes = ::testing::Types< + ResultSourcePair, + ResultSourcePair, + + ResultSourcePair, + ResultSourcePair, + + ResultSourcePair, + ResultSourcePair, + + ResultSourcePair, + ResultSourcePair, + ResultSourcePair, + ResultSourcePair +>; + +TYPED_TEST_SUITE(VectorArrayConverterTest, VectorConvertTypes); + +TYPED_TEST(VectorArrayConverterTest, array_1) { + TestFixture::template emit_test<1>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_2) { + TestFixture::template emit_test<2>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_3) { + TestFixture::template emit_test<3>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_4) { + TestFixture::template emit_test<4>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_5) { + TestFixture::template emit_test<5>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_8) { + TestFixture::template emit_test<8>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_10) { + // N > 8 and N is not a multiple of 4 + TestFixture::template emit_test<10>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_12) { + // N > 8 and N is a multiple of 4 + TestFixture::template emit_test<12>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_16) { + // N > 8 and N is a multiple of 8 + TestFixture::template emit_test<16>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_17) { + // N > 8 and N is not a multiple of 8 + TestFixture::template emit_test<17>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_27) { + // Test entire conversion range with residue (for int4) + TestFixture::template emit_test<27>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_31) { + // Force use of converters for 16, 8, 4, 2 and scalar + // if max width is 16 + TestFixture::template emit_test<31>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_63) { + // Force use of converters for 32, 16, 8, 4, 2 and scalar + // if max width is 32 + TestFixture::template emit_test<63>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_256) { + // Test entire conversion range (for int8) + TestFixture::template emit_test<256>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_259) { + // Force use of 4, 2 and scalar converter (if max width is 4) + TestFixture::template emit_test<259>(); +} + +TYPED_TEST(VectorArrayConverterTest, array_263) { + // Force use of 8, 4, 2 and scalar converter (if max width is 8) + TestFixture::template emit_test<263>(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index b333974d..9f14800c 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -42,10 +42,12 @@ cutlass_test_unit_add_executable( inverse_right.cpp logical_divide.cpp logical_product.cpp + math.cpp mixedbits.cpp nullspace.cpp pointer.cpp reverse.cpp transform.cpp tuple.cpp + int_tuple.cpp ) diff --git a/test/unit/cute/core/int_tuple.cpp b/test/unit/cute/core/int_tuple.cpp new file mode 100644 index 00000000..a801342a --- /dev/null +++ b/test/unit/cute/core/int_tuple.cpp @@ -0,0 +1,131 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +TEST(CuTe_core, WeaklyCongruent) +{ + using namespace cute; + + auto a = _1{}; + auto b = _2{}; + EXPECT_TRUE (weakly_congruent(a, a)); + EXPECT_TRUE (weakly_congruent(b, b)); + EXPECT_TRUE (weakly_congruent(a, b)); + + auto a0 = Shape<_1>{}; + auto b0 = Shape<_2>{}; + EXPECT_TRUE (weakly_congruent(a , a0)); + EXPECT_TRUE (weakly_congruent(b , b0)); + EXPECT_TRUE (weakly_congruent(a , b0)); + EXPECT_TRUE (weakly_congruent(b , a0)); + EXPECT_FALSE(weakly_congruent(a0, a )); + EXPECT_FALSE(weakly_congruent(b0, b )); + EXPECT_FALSE(weakly_congruent(a0, b )); + EXPECT_FALSE(weakly_congruent(b0, a )); + EXPECT_TRUE (weakly_congruent(a0, a0)); + EXPECT_TRUE (weakly_congruent(b0, b0)); + EXPECT_TRUE (weakly_congruent(a0, b0)); + + auto a1 = Shape<_1, _1>{}; + EXPECT_TRUE (weakly_congruent(a , a1)); + EXPECT_FALSE(weakly_congruent(a0, a1)); + EXPECT_TRUE (weakly_congruent(a1, a1)); + + auto a2 = Shape<_1, Shape<_1,_1>>{}; + EXPECT_TRUE (weakly_congruent(a , a2)); + EXPECT_FALSE(weakly_congruent(a0, a2)); + EXPECT_TRUE (weakly_congruent(a1, a2)); + + auto b1 = Shape<_2, _2>{}; + EXPECT_TRUE (weakly_congruent(b , b1)); + EXPECT_FALSE(weakly_congruent(b0, b1)); + EXPECT_TRUE (weakly_congruent(a1, b1)); + + auto b2 = Shape<_2, Shape<_2,_2>>{}; + EXPECT_FALSE(weakly_congruent(a2, b0)); + EXPECT_FALSE(weakly_congruent(a2, a1)); + EXPECT_TRUE (weakly_congruent(a2, b2)); + + auto b3 = Shape, Shape<_2,_2>>{}; + EXPECT_FALSE(weakly_congruent(a0, b3)); + EXPECT_TRUE (weakly_congruent(a1, b3)); + EXPECT_TRUE (weakly_congruent(a2, b3)); +} + +TEST(CuTe_core, WeaklyCompatible) +{ + using namespace cute; + + auto a = _16{}; + auto b = _12{}; + auto c = _8{}; + EXPECT_TRUE (weakly_compatible(a, a)); + EXPECT_TRUE (weakly_compatible(b, b)); + EXPECT_TRUE (weakly_compatible(c, c)); + EXPECT_FALSE(weakly_compatible(a, b)); + EXPECT_FALSE(weakly_compatible(a, c)); + EXPECT_TRUE (weakly_compatible(c, a)); + + auto a0 = Shape<_16>{}; + EXPECT_TRUE (weakly_compatible(a0, a0)); + EXPECT_TRUE (weakly_compatible(a , a0)); + EXPECT_FALSE(weakly_compatible(a0, a )); + EXPECT_TRUE (weakly_compatible(c , a0)); + EXPECT_FALSE(weakly_compatible(a0, c )); + EXPECT_FALSE(weakly_compatible(b , a0)); + EXPECT_FALSE(weakly_compatible(a0, b )); + + auto a1 = Shape<_2,_8>{}; + EXPECT_TRUE (weakly_compatible(a1, a1)); + EXPECT_TRUE (weakly_compatible(a , a1)); + EXPECT_FALSE(weakly_compatible(a0, a1)); + EXPECT_FALSE(weakly_compatible(a1, a0)); + EXPECT_TRUE (weakly_compatible(a1, Shape<_2,Shape<_2,_4>>{})); + + auto a2 = Shape>{}; + EXPECT_TRUE (weakly_compatible(a2, a2)); + EXPECT_TRUE (weakly_compatible(a , a2)); + EXPECT_TRUE (weakly_compatible(c , a2)); + EXPECT_TRUE (weakly_compatible(a0, a2)); + EXPECT_FALSE(weakly_compatible(a2, a0)); + + auto a3 = Shape>>{}; + EXPECT_TRUE (weakly_compatible(a3, a3)); + EXPECT_TRUE (weakly_compatible(a , a3)); + EXPECT_TRUE (weakly_compatible(c , a3)); + EXPECT_TRUE (weakly_compatible(a0, a3)); + EXPECT_FALSE(weakly_compatible(a3, a0)); + EXPECT_TRUE (weakly_compatible(a2, a3)); + EXPECT_FALSE(weakly_compatible(a3, a2)); +} diff --git a/test/unit/cute/core/math.cpp b/test/unit/cute/core/math.cpp new file mode 100644 index 00000000..1f3807a5 --- /dev/null +++ b/test/unit/cute/core/math.cpp @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include + +// If cute::gcd returns auto instead of common_type_t, +// then GCC 7.5 reports the following error; +// +// ... /include/cute/numeric/math.hpp:103:26: error: +// inconsistent deduction for auto return type: ‘int’ and then ‘bool’ +// if (u == 0) { return t; } +// ^ +// Note that common_type_t, C<1>>::value_type might still be bool. +TEST(CuTe_core, gcd_returns_common_type) +{ + using cute::C; + + constexpr auto fifteen = C<3 * 5>{}; + static_assert(cute::is_same_v); + static_assert(int(fifteen) == 15); + + constexpr auto forty_two = C<2 * 3 * 7>{}; + static_assert(cute::is_same_v); + static_assert(int(forty_two) == 42); + + // C<1>::value_type (as well as C<0>::value_type) may be bool. + constexpr auto one = C<1>{}; + + // Both inputs have value_type int. + { + constexpr auto result = cute::gcd(fifteen, forty_two); + static_assert(cute::is_same_v); + static_assert(int(result) == 3); + } + + // One input has value_type int, and the other may have value_type bool. + { + constexpr auto result = cute::gcd(one, forty_two); + static_assert(int(result) == 1); + } + { + constexpr auto result = cute::gcd(forty_two, one); + static_assert(int(result) == 1); + } + + // Both inputs may have value_type bool. + { + constexpr auto result = cute::gcd(one, one); + static_assert(int(result) == 1); + } +} + +TEST(CuTe_core, lcm_returns_common_type) +{ + using cute::C; + + constexpr auto six = C<2 * 3>{}; + static_assert(cute::is_same_v); + static_assert(int(six) == 6); + + constexpr auto fifteen = C<3 * 5>{}; + static_assert(cute::is_same_v); + static_assert(int(fifteen) == 15); + + // C<1>::value_type (as well as C<0>::value_type) may be bool. + constexpr auto one = C<1>{}; + + // Both inputs have value_type int. + { + constexpr auto result = cute::lcm(six, fifteen); + static_assert(cute::is_same_v); + static_assert(int(result) == 30); + } + + // One input has value_type int, and the other may have value_type bool. + { + constexpr auto result = cute::lcm(one, six); + static_assert(cute::is_same_v); + static_assert(int(result) == 6); + } + { + constexpr auto result = cute::lcm(six, one); + static_assert(cute::is_same_v); + static_assert(int(result) == 6); + } + + // Both inputs may have value_type bool. + { + constexpr auto result = cute::lcm(one, one); + static_assert(int(result) == 1); + } +} diff --git a/test/unit/cute/hopper/tma_load_testbed.hpp b/test/unit/cute/hopper/tma_load_testbed.hpp index 5e01345a..c3101261 100644 --- a/test/unit/cute/hopper/tma_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_load_testbed.hpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#pragma once + #include "cutlass_unit_test.h" #include diff --git a/test/unit/cute/hopper/tma_store_testbed.hpp b/test/unit/cute/hopper/tma_store_testbed.hpp index 47a31d9b..df0622db 100644 --- a/test/unit/cute/hopper/tma_store_testbed.hpp +++ b/test/unit/cute/hopper/tma_store_testbed.hpp @@ -29,6 +29,8 @@ * **************************************************************************************************/ +#pragma once + #include "cutlass_unit_test.h" #include diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index 96d78946..57f788ac 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -169,7 +169,7 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group - Layout>>; // 1x2x1 value group for 16x16x16 MMA and LDSM + Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group // A static constexpr int kAlignmentA = 8; @@ -301,7 +301,7 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group - Layout>>; // 1x2x1 value group for 16x16x8 and LDSM + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group // A static constexpr int kAlignmentA = 4; @@ -352,7 +352,7 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, Layout>, // 2x2x1 thread group - Layout>>; // 1x2x1 value group for 16x16x32 and LDSM + Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group // A (M,K) K-major using SmemLayoutAtomA = decltype( @@ -798,9 +798,9 @@ struct DefaultGemmConfigurationToCutlass3Types< using DispatchPolicy = MainloopSm80CpAsync<3>; using TiledMma = TiledMMA< MMA_Atom>, - Layout>, - Layout>, - Tile,Layout<_2,_16>,Underscore>>; + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization + Layout,Stride<_2,_1>>,Underscore>>; // A (M,K) M-major using SmemLayoutAtomA = Layout>; @@ -920,9 +920,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using DispatchPolicy = MainloopSm80CpAsync<3>; using TiledMma = TiledMMA< MMA_Atom>, - Layout>, - Layout>, - Tile,Underscore,Underscore>>; + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization // A (M,K) M-major using SmemLayoutAtomA = Layout>; @@ -982,9 +981,8 @@ struct DefaultGemmConfigurationToCutlass3Types< using DispatchPolicy = MainloopSm80CpAsync<3>; using TiledMma = TiledMMA< MMA_Atom>, - Layout>, - Layout>, - Tile,Underscore>>; + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization // A (M,K) K-major using SmemLayoutAtomA = Layout, @@ -1041,8 +1039,9 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, // Atom Layout>, // Atom layout - Layout>, // Val layout - Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; // A (M,K) K-Major using SmemLayoutAtomA = decltype( @@ -1119,8 +1118,9 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, // Atom Layout>, // Atom layout - Layout>, // Val layout - Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; // A (M,K) M-Major using SmemLayoutAtomA = decltype( @@ -1183,8 +1183,9 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, // Atom Layout>, // Atom layout - Layout>, // Val layout - Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; // A (M,K) M-Major using SmemLayoutAtomA = decltype( @@ -1247,8 +1248,9 @@ struct DefaultGemmConfigurationToCutlass3Types< using TiledMma = TiledMMA< MMA_Atom, // Atom Layout>, // Atom layout - Layout>, // Val layout - Tile,Layout<_2,_16>,Underscore>>; // Mode permutations + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; // A (M,K) K-Major using SmemLayoutAtomA = decltype( diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 86e146a5..394a66ae 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -58,6 +58,7 @@ #include "cutlass/fast_math.h" #include "cutlass/platform/platform.h" #include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" #include "cute/int_tuple.hpp" #include "cute/layout.hpp" @@ -192,6 +193,7 @@ struct TestbedImpl { using ActivationFunctor = ActivationFunctor_; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -248,6 +250,10 @@ struct TestbedImpl { // Used to force multi-wave tests for persistent kernel schedules constexpr static int MaxSmCount = 16; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + // // Methods // @@ -462,7 +468,7 @@ struct TestbedImpl { auto Vbeta = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, cute::_1{}))); - cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + cutlass::reference::host::GettMainloopParams mainloop_params{A, B, TransformA, TransformB}; cutlass::reference::host::GettEpilogueParams< ElementScalar, @@ -523,6 +529,9 @@ struct TestbedImpl { Gemm& gemm_op, typename Gemm::Arguments& arguments, cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); int L = 1; if constexpr(cute::rank(ProblemShapeType{}) == 4) { L = cute::size<3>(problem_size); @@ -561,7 +570,8 @@ struct TestbedImpl { detail::Iterations iterations = detail::Iterations{}, RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, - detail::Splits splits = detail::Splits{}) + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic) { // Fail test if insufficient CUDA device if (!sufficient()) { @@ -586,14 +596,6 @@ struct TestbedImpl { hw_info.sm_count = this->sm_count; } - typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; - if constexpr (std::is_same_v) { - scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order }; - } - else { - scheduler_args = { static_cast(max_swizzle), raster_order }; - } - // DefaultEpilogue auto arguments = typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGemm, @@ -606,10 +608,20 @@ struct TestbedImpl { {alpha, beta}, tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d }, - hw_info, - scheduler_args + hw_info }; + if constexpr (std::is_same_v) { + arguments.scheduler.splits = static_cast(splits); + arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); + arguments.scheduler.raster_order = raster_order; + arguments.scheduler.decomposition_mode = decomposition_mode; + + } else { + arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); + arguments.scheduler.raster_order = raster_order; + } + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -683,6 +695,7 @@ struct Testbed3x { using LayoutTagD = typename TestBedImpl::LayoutTagD; using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; // Detail Implementation TestBedImpl impl_; @@ -723,11 +736,12 @@ struct Testbed3x { RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, bool profiling = false, detail::Iterations iterations = detail::Iterations{}) { return impl_.run( - problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode ); } }; @@ -768,6 +782,7 @@ struct Testbed3xFusionOperation { static_assert(cute::is_base_of_v); using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; // fusion types are potentially void if the fusion is not supported // helper so we don't try to construct HostTensor with void type @@ -818,6 +833,7 @@ struct Testbed3xFusionOperation { cutlass::HostTensor abs_max_D; cutlass::HostTensor tensor_Aux; cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + // References cutlass::HostTensor reference_dbias; cutlass::HostTensor reference_Aux; @@ -977,7 +993,6 @@ struct Testbed3xFusionOperation { cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); } } - } template < @@ -1219,6 +1234,7 @@ struct Testbed3xFusionOperation { RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, bool profiling = false, detail::Iterations iterations = detail::Iterations{}) { @@ -1234,7 +1250,7 @@ struct Testbed3xFusionOperation { typename Gemm::Arguments arguments; cutlass::KernelHardwareInfo hw_info; cudaDeviceProp prop; - + hw_info.device_id = 0; if (not profiling) { impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); @@ -1251,11 +1267,6 @@ struct Testbed3xFusionOperation { /// A/B/C/D Tensor initialize(problem_size, alpha_, beta_); - typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; - if constexpr (std::is_same_v) { - scheduler_args = { static_cast(splits) }; - } - arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, @@ -1270,10 +1281,19 @@ struct Testbed3xFusionOperation { impl_.tensor_D.device_data(), impl_.stride_d }, // Epilogue arguments end - hw_info, - scheduler_args + hw_info }; - + + if constexpr (std::is_same_v) { + arguments.scheduler.splits = static_cast(splits); + arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); + arguments.scheduler.raster_order = raster_order; + arguments.scheduler.decomposition_mode = decomposition_mode; + } else { + arguments.scheduler.max_swizzle_size = static_cast(max_swizzle); + arguments.scheduler.raster_order = raster_order; + } + auto coord_0 = cutlass::make_Coord(0); if constexpr (IsLegacy) { arguments.epilogue.thread = { @@ -1313,12 +1333,18 @@ struct Testbed3xFusionOperation { } // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty if constexpr (cute::is_same_v>) { - // see ActivationFunctor::Arguments in activation.h for definition - // if Arguments doesn't exist then fusion_args.activation is empty fusion_args.activation.scale = ElementCompute(1); } + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + if constexpr (IsAbsMaxEnabledD) { fusion_args.amax_D_ptr = abs_max_D.device_data(); } @@ -1381,7 +1407,6 @@ struct Testbed3xFusionOperation { std::cout << "Error : Failed : with alpha: " << float(alpha_) << ", beta: " << float(beta_) << "\n"; } - return passed; } } @@ -1413,13 +1438,21 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + std::vector decomposition_modes = {DecompositionMode::Heuristic}; std::vector problem_splits = {1}; - if constexpr (std::is_same_v) { + static constexpr bool UsesStreamKScheduler = std::is_same_v; + if constexpr (UsesStreamKScheduler) { problem_splits.push_back(2); problem_splits.push_back(3); - // As many splits as there are maximum k tiles - problem_splits.push_back(Stages + 1); + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + + // Use larger K sizes for stream-K tests + static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; + problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment}; } using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; @@ -1433,33 +1466,53 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { for (int k : problem_size_k) { for (auto raster_order : raster_orders) { for (int max_swizzle_size : max_swizzle_sizes) { - for (int splits : problem_splits) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } + for (DecompositionMode decomp_mode : decomposition_modes) { - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - raster_order, - detail::MaxSwizzleSize(max_swizzle_size), - detail::Splits(splits) - ); + std::vector problem_splits = {1}; + if (UsesStreamKScheduler && (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK)) { + auto max_splits = (k + TileShapeK - 1) / TileShapeK; + if (max_splits > 2) { + problem_splits.push_back(2); + } + if (max_splits > 3) { + problem_splits.push_back(3); + } - if (!passed) { - return false; + problem_splits.push_back(max_splits); + + // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this + // case, split-K will fall back to a splitting factor of `max_splits`. + problem_splits.push_back(max_splits + 1); } - } - } - } - } - } - } + for (int splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + detail::MaxSwizzleSize(max_swizzle_size), + detail::Splits(splits), + decomp_mode + ); + + if (!passed) { + return false; + } + } // splits + } // decomposition_mode + } // max_swizzle_size + } // raster_order + } // k + } // n + } // m // if we do support batched GEMM, just run one test on it to save on test time if constexpr (cute::rank(ProblemShapeType{}) == 4) { diff --git a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp index 1a21840d..90034d07 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -382,7 +382,7 @@ public: HostAuxLoad(){} template HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) - : Base(check_relative_equality), impl_(impl){ + : Base(check_relative_equality), impl_(impl) { auto problem_shape_NMKL = cute::append<4>(problem_size, 1); auto [_M, _N, K, _L] = problem_shape_NMKL; auto aux_coord = cutlass::make_Coord(_M * _L, _N); diff --git a/test/unit/gemm/device/sm90_evt_operations.hpp b/test/unit/gemm/device/sm90_evt_operations.hpp index 71c6f2bb..1a21fa38 100644 --- a/test/unit/gemm/device/sm90_evt_operations.hpp +++ b/test/unit/gemm/device/sm90_evt_operations.hpp @@ -267,7 +267,7 @@ template< using Sm90LinCombAuxLoad = Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc @@ -295,7 +295,7 @@ template< using Sm90LinCombEVTDAG = Sm90EVT, // beta * C + (alpha * acc + aux) Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90TopologicalVisitor< ElementCompute, cute::tuple< @@ -349,7 +349,7 @@ using Sm90LinCombDAGEVT = Sm90EVT, Sm90ScalarBroadcast, Sm90AccFetch, - Sm90SrcFetch + Sm90SrcFetch > >, Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>, @@ -371,7 +371,7 @@ template< using Sm90LinCombPerColumnBias = Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc @@ -403,7 +403,7 @@ using Sm90LinCombPerColumnReduce = Sm90EVT, // per column reduce Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc Sm90ScalarBroadcast, // alpha Sm90AccFetch // acc @@ -428,7 +428,7 @@ using Sm90LinCombPerRowReduce = Sm90EVT, // per column reduce Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc Sm90ScalarBroadcast, // alpha Sm90AccFetch // acc @@ -452,7 +452,7 @@ using Sm90LinCombScalarReduce = Sm90EVT, // per column reduce Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta - Sm90SrcFetch, // C + Sm90SrcFetch, // C Sm90EVT, // alpha * acc Sm90ScalarBroadcast, // alpha Sm90AccFetch // acc diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu index ac90820d..408d9b31 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu @@ -389,7 +389,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -434,7 +434,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>; + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -480,7 +480,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; // ReLU with uint1b_t aux will compute dReLU/dZ as the aux output, i.e. Aux(i) = (Z(i) >= 0) ? 1 : 0 using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLU, cutlass::half_t, float, cutlass::uint1b_t, int8_t>; + LayoutC, cutlass::epilogue::thread::ReLU, cutlass::half_t, float, cutlass::uint1b_t, int8_t, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -525,7 +525,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< - LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>; + LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -570,7 +570,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltAct< - LayoutC, cutlass::epilogue::thread::dGELU, cutlass::half_t, float, cutlass::half_t>; + LayoutC, cutlass::epilogue::thread::dGELU, cutlass::half_t, float, cutlass::half_t, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu index 18660318..83f03e6d 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu @@ -335,7 +335,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -380,7 +380,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>; + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -425,7 +425,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>; + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -470,7 +470,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_pingpong_epilogue, 128x1 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< - LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>; + LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float, void>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu index fc4e3f31..ff4d4e2f 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu @@ -95,7 +95,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_1x2x1) { @@ -136,7 +137,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -178,7 +180,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -218,7 +221,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -258,7 +262,8 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -298,7 +303,8 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -341,7 +347,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -381,7 +388,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -421,7 +429,8 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -461,7 +470,8 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } @@ -505,7 +515,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -545,7 +556,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -585,7 +597,8 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -625,7 +638,8 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } @@ -669,7 +683,8 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -709,7 +724,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -749,7 +765,8 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -789,7 +806,8 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1) { @@ -827,7 +845,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_stream_k_epi >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1) { @@ -865,7 +884,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_stream_k_epi >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 128x128x64_2x2x1) { @@ -903,7 +923,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k_epi >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 128x128x64_2x2x1) { @@ -941,7 +962,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epi >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); } TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { @@ -985,8 +1007,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epi using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); - EXPECT_TRUE(passed); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu index 45b9d023..3af6e029 100644 --- a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu @@ -99,7 +99,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x1x1) { @@ -146,7 +147,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -197,7 +199,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x2x1) { @@ -244,7 +247,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -295,7 +299,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x4x1) { @@ -342,7 +347,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -393,7 +399,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_4x1x1) { @@ -440,7 +447,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } /////////////////////////////////////////////////////////////////////////////// @@ -491,7 +499,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_2x4x1_fp8_fast_accum) { @@ -538,7 +547,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 0.0)); + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise(1.0, 1.0)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu index 9942f2d7..62008d26 100644 --- a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu +++ b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu @@ -101,8 +101,34 @@ test_scheduler( cutlass::KernelHardwareInfo hw_info{0, sm_count}; auto params = Scheduler::to_underlying_arguments(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, {splits}, nullptr); + typename Scheduler::Arguments args{}; + + // Set up the grid for the problem + dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args); + + auto print_info = [&]() { + std::cout << "Failed with problem size " + << size<0>(problem_shape_mnkl) << "x" + << size<1>(problem_shape_mnkl) << "x" + << size<2>(problem_shape_mnkl) << "x" + << size<3>(problem_shape_mnkl) + << " and grid size " << grid.x << "x" + << grid.y << "x" << grid.z + << " splits=" << params.splits_ + << " k_iter=" << params.divmod_tiles_per_output_tile_.divisor + << " big_units_=" << params.big_units_ + << " big_groups_=" << params.big_groups_ + << " sk_tiles=" << params.sk_tiles_ + << " sk_units=" << params.sk_units_ + << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ + << " units_per_problem=" << params.units_per_problem_ + << " groups=" << params.divmod_sk_groups_.divisor << std::endl; + }; + // If we expect the schedule to be data-parallel only, ensure that no stream-K tiles are launched. if (expect_data_parallel && params.sk_tiles_ != 0) { + print_info(); + std::cout << "Expected stream-K to select a data-parallel decomposition." << std::endl; return false; } @@ -114,15 +140,11 @@ test_scheduler( // Initialize counters to zero cudaError_t err = cudaMemset((void*)visit_counters.get(), 0, sizeof(int) * total_counters); if (err != cudaSuccess) { - std::cerr << __FILE__ << ":" << __LINE__ << " cudaMemset failed with error: " << cudaGetErrorString(err) << std::endl; + print_info(); + std::cout << __FILE__ << ":" << __LINE__ << " cudaMemset failed with error: " << cudaGetErrorString(err) << std::endl; return false; } - typename Scheduler::Arguments args{}; - - // Set up the grid for the problem - dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args); - // Set up cluster and cluster launch. This is needed even for this simple kernel because // the SM90 scheduler needs to be able to query the CTA id within a cluster, which requires // explicitly launching with clusters. @@ -161,7 +183,8 @@ test_scheduler( err = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); if (err != cudaSuccess) { - std::cerr << __FILE__ << ":" << __LINE__ + print_info(); + std::cout << __FILE__ << ":" << __LINE__ << " cudaLaunchKernelExC failed with error: " << cudaGetErrorString(err) << std::endl; return false; @@ -169,7 +192,8 @@ test_scheduler( err = cudaDeviceSynchronize(); if (err != cudaSuccess) { - std::cerr << __FILE__ << ":" << __LINE__ + print_info(); + std::cout << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl; return false; @@ -181,20 +205,7 @@ test_scheduler( for (size_t i = 0; i < host_visit_counts.size(); ++i) { if (host_visit_counts[i] != 1) { - std::cout << "Failed with problem size " - << size<0>(problem_shape_mnkl) << "x" - << size<1>(problem_shape_mnkl) << "x" - << size<2>(problem_shape_mnkl) << "x" - << size<3>(problem_shape_mnkl) - << " and grid size " << grid.x << "x" - << grid.y << "x" << grid.z - << " splits=" << params.splits_ - << " k_iter=" << params.divmod_tiles_per_output_tile_.divisor - << " big_units=" << params.big_units_ - << " sk_tiles=" << params.sk_tiles_ - << " sk_units=" << params.sk_units_ - << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ - << " units_per_problem=" << params.units_per_problem_ << std::endl; + print_info(); std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl; return false; } @@ -301,7 +312,7 @@ TEST(SM90_Device_Gemm_stream_k_scheduler, 256x128x64_2x1x1) { // Test various data-parallel cases EXPECT_TRUE(test_data_parallel(/*blocks_m=*/ 4, /*blocks_n=*/ 4, tile_shape, cluster_shape, /*sm_count=*/ 16)); EXPECT_TRUE(test_data_parallel(/*blocks_m=*/16, /*blocks_n=*/ 4, tile_shape, cluster_shape, /*sm_count=*/ 64)); - EXPECT_TRUE(test_data_parallel(/*blocks_m=*/ 4, /*blocks_n=*/27, tile_shape, cluster_shape, /*sm_count=*/108)); + EXPECT_TRUE(test_data_parallel(/*blocks_m=*/ 8, /*blocks_n=*/27, tile_shape, cluster_shape, /*sm_count=*/108)); // Test various stream-K cases EXPECT_TRUE(test_stream_k(tile_shape, cluster_shape, /*sm_count=*/ 16)); diff --git a/test/unit/nvrtc/thread/nvrtc_contraction.cu b/test/unit/nvrtc/thread/nvrtc_contraction.cu index 934523b5..80cf3a0d 100644 --- a/test/unit/nvrtc/thread/nvrtc_contraction.cu +++ b/test/unit/nvrtc/thread/nvrtc_contraction.cu @@ -43,6 +43,7 @@ static_assert(0, "CUDA include path is not defined"); #endif +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) TEST(SM90_nvrtc_kernel, Contraction) { static const char* nvrtc_opts[] = { "-w", @@ -62,5 +63,5 @@ TEST(SM90_nvrtc_kernel, Contraction) { { nvrtc_opts, nvrtc_opts + 5 } )); } - +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/pipeline/pipeline_tma_async.cu b/test/unit/pipeline/pipeline_tma_async.cu index 3253dfe2..f76c7911 100644 --- a/test/unit/pipeline/pipeline_tma_async.cu +++ b/test/unit/pipeline/pipeline_tma_async.cu @@ -60,10 +60,10 @@ using namespace cute; //////////////////// KERNEL ///////////////////////// -template +template struct SharedStorage { - typename cutlass::PipelineTmaAsync::SharedStorage storage; + typename cutlass::PipelineTmaAsync::SharedStorage storage; }; // Goal of this kernel is to complete deadlock-free @@ -73,10 +73,10 @@ void pipeline_device(uint32_t const NumIterations) { extern __shared__ char shared_memory[]; - using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipeline = cutlass::PipelineTmaAsync; using PipelineState = cutlass::PipelineState; - using SharedStorage = SharedStorage; + using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id @@ -98,7 +98,7 @@ void pipeline_device(uint32_t const NumIterations) params.is_leader = warp_group_thread_idx == 0; params.num_consumers = 128; - MainloopPipeline pipeline(shared_storage.storage, params); + MainloopPipeline pipeline(shared_storage.storage, params, cluster_shape); __syncthreads(); @@ -223,7 +223,7 @@ struct PipelineTest { } for (int iter = 0; iter < iterations; ++iter) { - int smem_size = int(sizeof(SharedStorage)); + int smem_size = int(sizeof(SharedStorage)); result = cudaFuncSetAttribute( pipeline_device, diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu index c6fa463a..3a3aecc6 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu @@ -62,10 +62,10 @@ using namespace cutlass; //////////////////// KERNEL ///////////////////////// -template +template struct SharedStorage { - typename cutlass::PipelineTmaAsync::SharedStorage storage ; + typename cutlass::PipelineTmaAsync::SharedStorage storage ; }; struct KernelParams @@ -81,10 +81,10 @@ __global__ static void pipeline_device(KernelParams const kernel_params) { extern __shared__ char shared_memory[]; - using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; - using SharedStorage = SharedStorage; + using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id @@ -112,7 +112,7 @@ void pipeline_device(KernelParams const kernel_params) params.is_leader = warp_group_thread_idx == 0; params.num_consumers = 128; - MainloopPipeline pipeline(shared_storage.storage, params); + MainloopPipeline pipeline(shared_storage.storage, params, cluster_shape); __syncthreads(); @@ -292,9 +292,9 @@ struct PipelineTest { for (int iter = 0; iter < iterations; ++iter) { - using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; - int smem_size = int(sizeof(SharedStorage)); + int smem_size = int(sizeof(SharedStorage)); result = cudaFuncSetAttribute( pipeline_device, diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu index efb389be..32aa290b 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu @@ -62,16 +62,16 @@ using namespace cutlass; //////////////////// KERNEL ///////////////////////// -template +template struct SharedStorage { - typename cutlass::PipelineTmaAsync::SharedStorage pipeline_storage; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_storage; typename PingPongBarrier::SharedStorage pingpong_storage; }; template struct CollectiveSimulation { - using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; CUTLASS_DEVICE @@ -198,7 +198,7 @@ __global__ static void pipeline_device(KernelParams params) { extern __shared__ char shared_memory[]; - using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; /* One for Mainloop and one for Epilogue */ @@ -206,7 +206,7 @@ void pipeline_device(KernelParams params) constexpr int MathWarpGroupCountPersistent = 2; using PingPongBarrier = typename cutlass::OrderedSequenceBarrier; - using SharedStorage = SharedStorage; + using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id @@ -232,7 +232,7 @@ void pipeline_device(KernelParams params) pipeline_params.is_leader = warp_group_thread_idx == 0; pipeline_params.num_consumers = NumThreadsPerWarpGroup; - MainloopPipeline pipeline(shared_storage.pipeline_storage, pipeline_params); + MainloopPipeline pipeline(shared_storage.pipeline_storage, pipeline_params, cluster_shape); PipelineState tile_start_state_pipe; int tiles_per_cluster = params.tiles_per_cluster; @@ -343,7 +343,7 @@ struct PipelineTest { for (int iter = 0; iter < iterations; ++iter) { constexpr int StagesPerMathWarpGroup = 2; constexpr int MathWarpGroupCountPersistent = 2; - int smem_size = int(sizeof(SharedStorage>)); result = cudaFuncSetAttribute( diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index fd6a0a0f..6a1aa6b5 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -136,6 +136,8 @@ function(cutlass_add_cutlass_library) EXPORT_NAME ${__EXPORT_NAME} "" ) + + target_compile_features(${__NAME} INTERFACE cxx_std_17) set_target_properties( ${__NAME} @@ -159,6 +161,8 @@ function(cutlass_add_cutlass_library) EXPORT_NAME ${__EXPORT_NAME}_static "" ) + + target_compile_features(${__NAME}_static INTERFACE cxx_std_17) if (WIN32) set(STATIC_OUTPUT_NAME ${__OUTPUT_NAME}.static) @@ -196,8 +200,8 @@ function(cutlass_add_cutlass_library) # to the main cutlass library so users automatically get the necessary link # commands to pull in all kernels by default. - target_link_libraries(${DEFAULT_NAME} INTERFACE ${__NAME}) - target_link_libraries(${DEFAULT_NAME}_static INTERFACE ${__NAME}_static) + target_link_libraries(${DEFAULT_NAME} PUBLIC ${__NAME}) + target_link_libraries(${DEFAULT_NAME}_static PUBLIC ${__NAME}_static) endif() @@ -246,6 +250,7 @@ cutlass_add_cutlass_library( # For backward compatibility with the old name add_library(cutlass_lib ALIAS cutlass_library) +add_library(cutlass_lib_static ALIAS cutlass_library_static) ################################################################################ diff --git a/tools/library/include/cutlass/library/handle.h b/tools/library/include/cutlass/library/handle.h index 93070f31..21cf34c4 100644 --- a/tools/library/include/cutlass/library/handle.h +++ b/tools/library/include/cutlass/library/handle.h @@ -65,7 +65,7 @@ private: /// Size of device workspace in bytes size_t workspace_size_; - + /// Indicates whether scalars are host or device pointers ScalarPointerMode scalar_pointer_mode_; @@ -89,7 +89,7 @@ public: // // Persistent state accessors // - + /// Returns compute capability of the selected device int compute_capability() const; @@ -135,7 +135,7 @@ public: int K, /// GEMM K dimension NumericTypeID element_compute, /// Data type of internal accumulation - + NumericTypeID element_scalar, /// Data type of alpha/beta scalars void const *alpha, /// Pointer to alpha scalar @@ -164,7 +164,7 @@ public: void * ptr_D, /// Pointer to D matrix int64_t ldd /// Leading dimension of D matrix ); - + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. // // Supports batched-strided, batched array or split-K serial or split-K parallel. @@ -176,7 +176,6 @@ public: int M, /// GEMM M dimension int N, /// GEMM N dimension int K, /// GEMM K dimension - NumericTypeID element_compute, /// Data type of internal accumulation NumericTypeID element_scalar, /// Data type of alpha/beta scalars @@ -218,7 +217,7 @@ public: /// Planar complex GEMM /// /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. - /// + /// Status gemm_planar_complex( int M, /// GEMM M dimension @@ -245,7 +244,7 @@ public: ComplexTransform transform_B, /// Complex transformation applied to B matrix void const * ptr_B_real, /// Pointer to real part of B matrix - void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix int64_t ldb_real, /// Leading dimension of real part of B matrix int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix @@ -301,7 +300,7 @@ public: ComplexTransform transform_A, /// Complex transformation applied to A matrix void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices - void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices int64_t lda_real, /// Leading dimension of real part of A matrix int64_t lda_imag, /// Leading dimension of imaginary part of A matrix diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 3c945d14..f572a2e8 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -28,17 +28,17 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! +/*! \file \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. Generally, - + description - compile-time constant parameters used to instantiate an operation - configuration - runtime parameters with computationally expensive initialization - + configuration - runtime parameters with computationally expensive initialization + arguments - runtime parameters that may be passed to an initialized operation with low computational overhead */ @@ -87,26 +87,26 @@ public: virtual OperationDescription const & description() const = 0; virtual Status can_implement( - void const *configuration, + void const *configuration, void const *arguments) const = 0; - + virtual uint64_t get_host_workspace_size( void const *configuration) const = 0; - + virtual uint64_t get_device_workspace_size( void const *configuration, void const *arguments = nullptr) const = 0; - + virtual Status initialize( - void const *configuration, - void *host_workspace, - void *device_workspace = nullptr, + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; virtual Status run( void const *arguments, - void *host_workspace, - void *device_workspace = nullptr, + void *host_workspace, + void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; }; @@ -217,7 +217,7 @@ using GemmBatchedArguments = GemmArguments; struct GemmArrayConfiguration { gemm::GemmCoord problem_size; - + /// Leading dimension of A matrix int64_t lda; @@ -241,7 +241,7 @@ struct GemmArrayArguments { void * const *D; void const *alpha; void const *beta; - ScalarPointerMode pointer_mode; + ScalarPointerMode pointer_mode; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -264,7 +264,7 @@ struct GemmUniversalConfiguration { }; struct GemmUniversalArguments { - // NOTE: these are replicated for 3.0 interfaces + // NOTE: these are replicated for 3.0 interfaces gemm::GemmCoord problem_size; int batch_count; @@ -645,8 +645,8 @@ struct SymmArguments { struct Conv2dConfiguration { conv::SplitKMode split_k_mode; - - /// Conv2d problem size + + /// Conv2d problem size // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) // also includes (split_k_slices, groups) conv::Conv2dProblemSize problem_size; @@ -669,8 +669,8 @@ struct Conv2dConfiguration { struct Conv3dConfiguration { conv::SplitKMode split_k_mode; - - /// Conv2d problem size + + /// Conv2d problem size // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) // also includes (split_k_slices, groups) conv::Conv3dProblemSize problem_size; @@ -688,7 +688,7 @@ struct Conv3dConfiguration { layout::TensorNDHWC layout_output; // - // Methods + // Methods // // Mapping functions (A,B,C -> activation,filter,output) @@ -734,7 +734,7 @@ struct ConvArguments { /// pointer to reordered matrix B void const *reordered_B; - + /// pointer to implicit gemm matrix C void const *C; @@ -770,7 +770,7 @@ struct ReductionConfiguration { int64_t partition_stride; /// leading dimension of 'w'orkspace operand - int64_t ldw; + int64_t ldw; /// leading dimension of 's'ource operand int64_t lds; diff --git a/tools/library/include/cutlass/library/manifest.h b/tools/library/include/cutlass/library/manifest.h index abce958b..b844beca 100644 --- a/tools/library/include/cutlass/library/manifest.h +++ b/tools/library/include/cutlass/library/manifest.h @@ -90,7 +90,11 @@ public: Status release(); /// Appends an operation and takes ownership - void append(Operation *operation_ptr); + void append(Operation *operation_ptr) {\ + // This function is inline s.t. it is present in generated libraries + // without having to compile or link in manifest.cpp + operations_.emplace_back(operation_ptr); + } /// Returns an iterator to the first operation OperationVector const &operations() const; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 90ddac48..9f1031a8 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -257,11 +257,11 @@ protected: case RasterOrder::kAlongM: operator_args.scheduler.raster_order = Enum_t::AlongM; break; - default: + default: operator_args.scheduler.raster_order = Enum_t::Heuristic; } } - + return status; } @@ -271,7 +271,7 @@ public: Status can_implement( void const *configuration_ptr, void const *arguments_ptr) const override { - GemmUniversalConfiguration const *configuration = + GemmUniversalConfiguration const *configuration = static_cast(configuration_ptr); GemmUniversalArguments const *arguments = static_cast(arguments_ptr); diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index 38d53f4e..d0623039 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -32,7 +32,7 @@ /*! \file \brief CUTLASS Library handle. */ -#include +#include #include #include @@ -47,14 +47,14 @@ namespace library { /// Constructor Handle::Handle( - cudaStream_t stream, + cudaStream_t stream, size_t workspace_size ): - provider_(Provider::kCUTLASS), - stream_(stream), - workspace_(nullptr), - workspace_size_(0), - scalar_pointer_mode_(ScalarPointerMode::kHost), + provider_(Provider::kCUTLASS), + stream_(stream), + workspace_(nullptr), + workspace_size_(0), + scalar_pointer_mode_(ScalarPointerMode::kHost), last_operation_(nullptr) { int device_idx = -1; @@ -94,7 +94,7 @@ Handle::Handle(Handle && handle) { workspace_ = handle.workspace_; stream_ = handle.stream_; scalar_pointer_mode_ = handle.scalar_pointer_mode_; - + handle.workspace_ = nullptr; handle.workspace_size_ = 0; } @@ -156,14 +156,14 @@ void Handle::set_workspace_size(size_t bytes) { if (workspace_) { cudaFree(workspace_); } - + workspace_ = nullptr; workspace_size_ = bytes; if (workspace_size_) { - + cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); - + if (error != cudaSuccess) { throw std::runtime_error("Failed to allocate workspace"); } @@ -239,7 +239,7 @@ static int gemm_problem_alignment( }; for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { - + bool satisfied = true; // Can pointers satisfy this? @@ -260,7 +260,7 @@ static int gemm_problem_alignment( int max_element_alignment = 0; for (NumericTypeID type_id : elements) { - int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); + int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id); max_element_alignment = std::max(max_element_alignment, element_alignment); } @@ -286,7 +286,7 @@ static int gemm_problem_alignment( /// Find the best kernel in descending order of preference. static Operation const * find_gemm_operation( - GemmOperationFunctionalMap::const_iterator operators_it, + GemmOperationFunctionalMap::const_iterator operators_it, GemmPreferenceKey const preference_key) { auto cc_it = operators_it->second.upper_bound(preference_key); @@ -363,7 +363,7 @@ Status Handle::gemm( void * ptr_D, /// Pointer to D matrix int64_t ldd /// Leading dimension of D matrix ) { - + // // Find the operation // @@ -390,7 +390,7 @@ Status Handle::gemm( if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { return cutlass::Status::kErrorNotSupported; } - + if (operators_it->second.empty()) { return cutlass::Status::kErrorNotSupported; } @@ -403,7 +403,7 @@ Status Handle::gemm( int const kMaximumAlignmentSize = 16; int alignment = gemm_problem_alignment( - M, N, K, + M, N, K, element_A, ptr_A, lda, 0, element_B, ptr_B, ldb, 0, element_C, ptr_C, ldc, 0, @@ -491,7 +491,6 @@ Status Handle::gemm_universal( int M, /// GEMM M dimension int N, /// GEMM N dimension int K, /// GEMM K dimension - NumericTypeID element_compute, /// Data type of internal accumulation NumericTypeID element_scalar, /// Data type of alpha/beta scalars @@ -529,7 +528,7 @@ Status Handle::gemm_universal( int64_t batch_stride_C, /// Batch stride of C operand int64_t batch_stride_D /// Batch stride of D operand ) { - + // // Find the operation // @@ -556,7 +555,7 @@ Status Handle::gemm_universal( if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { return cutlass::Status::kErrorNotSupported; } - + if (operators_it->second.empty()) { return cutlass::Status::kErrorNotSupported; } @@ -576,14 +575,14 @@ Status Handle::gemm_universal( // Ignore alignment of pointers to pointers. We can't check this from the host, // as each batch index has its own pointer in device memory. if (mode == GemmUniversalMode::kArray) { - ptr_A_check = nullptr; - ptr_B_check = nullptr; - ptr_C_check = nullptr; - ptr_D_check = nullptr; + ptr_A_check = nullptr; + ptr_B_check = nullptr; + ptr_C_check = nullptr; + ptr_D_check = nullptr; } int alignment = gemm_problem_alignment( - M, N, K, + M, N, K, element_A, ptr_A_check, lda, 0, element_B, ptr_B_check, ldb, 0, element_C, ptr_C_check, ldc, 0, @@ -758,7 +757,7 @@ Status Handle::gemm_planar_complex( if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { return cutlass::Status::kErrorNotSupported; } - + if (operators_it->second.empty()) { return cutlass::Status::kErrorNotSupported; } @@ -772,14 +771,14 @@ Status Handle::gemm_planar_complex( int alignment = std::max( gemm_problem_alignment( - M, N, K, + M, N, K, element_A, ptr_A_real, lda_real, batch_stride_A_real, element_B, ptr_B_real, ldb_real, batch_stride_B_real, element_C, ptr_C_real, ldc_real, batch_stride_C_real, ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize ), gemm_problem_alignment( - M, N, K, + M, N, K, element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, @@ -928,7 +927,7 @@ Status Handle::gemm_planar_complex_array( int64_t ldd_real, /// Leading dimension of real part of D matrix int64_t ldd_imag /// Leading dimension of imaginary part of D matrix ) { - + // // Find the operation // @@ -955,7 +954,7 @@ Status Handle::gemm_planar_complex_array( if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { return cutlass::Status::kErrorNotSupported; } - + if (operators_it->second.empty()) { return cutlass::Status::kErrorNotSupported; } @@ -969,14 +968,14 @@ Status Handle::gemm_planar_complex_array( int alignment = std::max( gemm_problem_alignment( - expected_M, expected_N, expected_K, + expected_M, expected_N, expected_K, element_A, nullptr, lda_real, 0, element_B, nullptr, ldb_real, 0, element_C, nullptr, ldc_real, 0, nullptr, ldd_real, 0, kMaximumAlignmentSize ), gemm_problem_alignment( - expected_M, expected_N, expected_K, + expected_M, expected_N, expected_K, element_A, nullptr, lda_imag, 0, element_B, nullptr, ldb_imag, 0, element_C, nullptr, ldc_imag, 0, @@ -1066,7 +1065,7 @@ Status Handle::gemm_planar_complex_array( /// Finds conv operation instances with Conv::ElementC = Reduction::ElementWorkspace Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation) { - ConvDescription const &conv_desc = + ConvDescription const &conv_desc = static_cast(operation->description()); // if the curren conv operation accumulator and output data type match return operation @@ -1077,19 +1076,19 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope // find conv operation to match conv output and reduction workspace data type ConvFunctionalKey key( library::Provider::kCUTLASS, - conv_desc.conv_kind, + conv_desc.conv_kind, conv_desc.A.element, conv_desc.A.layout, conv_desc.B.element, conv_desc.B.layout, conv_desc.tile_description.math_instruction.element_accumulator, conv_desc.C.layout, - conv_desc.tile_description.math_instruction.element_accumulator, + conv_desc.tile_description.math_instruction.element_accumulator, conv_desc.element_epilogue); // conv operation table for conv2d or conv3d - auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ? - Singleton::get().operation_table.conv2d_operations : + auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ? + Singleton::get().operation_table.conv2d_operations : Singleton::get().operation_table.conv3d_operations; // find ConvFunctionalKey in convolution operation table @@ -1098,18 +1097,18 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope if (operators_it == conv_operations.end()) { return nullptr; } - + if (operators_it->second.empty()) { return nullptr; } // conv operation for same compute capability and iterator algorithm ConvPreferenceKey preference_key( - conv_desc.tile_description.minimum_compute_capability, + conv_desc.tile_description.minimum_compute_capability, conv_desc.iterator_algorithm); auto it = operators_it->second.find(preference_key); - + if(it == operators_it->second.end()) { return nullptr; } @@ -1129,7 +1128,7 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope /// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) { - GemmDescription const &gemm_desc = + GemmDescription const &gemm_desc = static_cast(operation->description()); // if the curren gemm operation accumulator and output data type match return operation @@ -1174,7 +1173,7 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope gemm_desc.B.alignment); GemmPreferenceKey preference_key( - gemm_desc.tile_description.minimum_compute_capability, + gemm_desc.tile_description.minimum_compute_capability, alignment); auto it = operators_it->second.find(preference_key); diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index 1f3c456a..31313488 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -77,11 +77,6 @@ Status Manifest::release() { return Status::kSuccess; } -/// Appends an operation and takes ownership -void Manifest::append(Operation *operation_ptr) { - operations_.emplace_back(operation_ptr); -} - /// Returns an iterator to the first operation OperationVector const & Manifest::operations() const { return operations_; diff --git a/tools/library/src/reference/conv_reference_operation.h b/tools/library/src/reference/conv_reference_operation.h index 4eac5deb..2d830add 100644 --- a/tools/library/src/reference/conv_reference_operation.h +++ b/tools/library/src/reference/conv_reference_operation.h @@ -45,6 +45,7 @@ #include "cutlass/library/util.h" #include "library_internal.h" +#include "cutlass/conv/convolution.h" #include "cutlass/util/reference/host/convolution.h" #include "cutlass/util/reference/device/convolution.h" @@ -59,7 +60,7 @@ namespace detail { template < Provider kProvider, - conv::Operator ConvolutionalOperator, + cutlass::conv::Operator ConvolutionalOperator, int ConvDim, typename ElementA_, typename LayoutA_, @@ -77,7 +78,7 @@ struct ConvReferenceDispatcher; /// Dispatcher for Conv2d (partially specialized for kConvDim == 2) template < Provider kProvider, - conv::Operator kConvolutionalOperator, + cutlass::conv::Operator kConvolutionalOperator, typename ElementA, typename LayoutA, typename ElementB, @@ -193,7 +194,7 @@ struct ConvReferenceDispatcher< /// Dispatcher for Conv3d (partially specialized for kConvDim == 3) template < Provider kProvider, - conv::Operator kConvolutionalOperator, + cutlass::conv::Operator kConvolutionalOperator, typename ElementA, typename LayoutA, typename ElementB, @@ -292,7 +293,7 @@ struct ConvReferenceDispatcher< template < Provider Provider_, - conv::Operator ConvolutionalOperator, + cutlass::conv::Operator ConvolutionalOperator, int ConvDim, typename ElementA_, typename LayoutA_, @@ -308,7 +309,7 @@ template < class ConvReferenceOperation : public Operation { public: static Provider const kProvider = Provider_; - static conv::Operator const kConvolutionalOperator = ConvolutionalOperator; + static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; static int const kConvDim = ConvDim; using ElementA = ElementA_; @@ -491,7 +492,7 @@ void make_conv_fprop(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceHost, - conv::Operator::kFprop, + cutlass::conv::Operator::kFprop, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, @@ -504,7 +505,7 @@ void make_conv_fprop(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceDevice, - conv::Operator::kFprop, + cutlass::conv::Operator::kFprop, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, @@ -534,7 +535,7 @@ void make_conv_backwards(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceHost, - conv::Operator::kDgrad, + cutlass::conv::Operator::kDgrad, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, @@ -547,7 +548,7 @@ void make_conv_backwards(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceDevice, - conv::Operator::kDgrad, + cutlass::conv::Operator::kDgrad, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, @@ -560,7 +561,7 @@ void make_conv_backwards(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceHost, - conv::Operator::kWgrad, + cutlass::conv::Operator::kWgrad, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, @@ -573,7 +574,7 @@ void make_conv_backwards(Manifest &manifest) { manifest.append(new ConvReferenceOperation< Provider::kReferenceDevice, - conv::Operator::kWgrad, + cutlass::conv::Operator::kWgrad, kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_, diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 28914a69..fe5aae0b 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -67,11 +67,12 @@ public: /// Problem structure obtained from problem space struct GemmProblem { - cutlass::library::GemmUniversalMode mode; + cutlass::library::GemmUniversalMode mode; int64_t m; int64_t n; int64_t k; + int64_t lda; int64_t ldb; int64_t ldc; @@ -93,9 +94,16 @@ public: // Methods // - GemmProblem(): + GemmProblem(): mode(library::GemmUniversalMode::kGemm), - m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1), + m(16), + n(16), + k(16), + lda(0), + ldb(0), + ldc(0), + split_k_slices(1), + batch_count(1), raster_order(cutlass::library::RasterOrder::kHeuristic){ } /// Parses the problem @@ -117,7 +125,7 @@ public: ProblemSpace const &problem_space); }; - /// Workspace used + /// Workspace used struct GemmWorkspace { DeviceAllocation *A; @@ -150,7 +158,7 @@ public: // Methods // - GemmWorkspace(): + GemmWorkspace(): A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { } }; @@ -163,7 +171,7 @@ protected: /// GEMM problem obtained from problem space GemmProblem problem_; - /// Device memory allocations + /// Device memory allocations GemmWorkspace gemm_workspace_; /// CUTLASS parallel reduction operation to follow this* gemm operation @@ -190,8 +198,8 @@ public: /// Extracts the problem dimensions virtual Status initialize_configuration( - Options const &options, - PerformanceReport &report, + Options const &options, + PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, @@ -199,8 +207,8 @@ public: /// Initializes workspace virtual Status initialize_workspace( - Options const &options, - PerformanceReport &report, + Options const &options, + PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, @@ -208,7 +216,7 @@ public: /// Verifies CUTLASS against references virtual bool verify_cutlass( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -217,8 +225,8 @@ public: /// Measures performance results virtual bool profile( - Options const &options, - PerformanceReport &report, + Options const &options, + PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, @@ -229,13 +237,13 @@ protected: /// Initializes the performance result void initialize_result_( PerformanceResult &result, - Options const &options, + Options const &options, library::GemmDescription const &operation_desc, ProblemSpace const &problem_space); /// Verifies CUTLASS against references bool verify_with_cublas_( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -244,7 +252,7 @@ protected: /// Verifies CUTLASS against host and device references bool verify_with_reference_( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index f16ccf7d..0419692f 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -1493,7 +1493,6 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); - case library::NumericTypeID::kF16: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -1633,7 +1632,7 @@ bool DeviceAllocation::block_compare_equal( capacity); default: - throw std::runtime_error("Unsupported numeric type"); + throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type)); } } @@ -1662,7 +1661,6 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); - case library::NumericTypeID::kF16: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -2089,8 +2087,12 @@ void DeviceAllocation::write_tensor_csv( write_tensor_csv_static_type >(out, *this); break; + case library::NumericTypeID::kVoid: + // Not dump anything as it is a empty tensor. + break; + default: - throw std::runtime_error("Unsupported numeric type"); + throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type()) ) ; } } @@ -2168,7 +2170,6 @@ void DeviceAllocation::fill(double val = 0.0) { case library::NumericTypeID::kFE5M2: tensor_fill(*this, static_cast(val)); break; - case library::NumericTypeID::kF16: tensor_fill(*this, static_cast(val)); break; @@ -2254,7 +2255,7 @@ void DeviceAllocation::fill(double val = 0.0) { break; default: - throw std::runtime_error("Unsupported numeric type"); + throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(this->type())); } } diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index f50b4d4a..835b3911 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -55,7 +55,7 @@ namespace profiler { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Ctor -GemmOperationProfiler::GemmOperationProfiler(Options const &options): +GemmOperationProfiler::GemmOperationProfiler(Options const &options): OperationProfiler( options, library::OperationKind::kGemm, @@ -73,7 +73,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"}, {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, - {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"}, + {ArgumentTypeID::kEnumerated, {"raster_order", "raster-order"}, "Raster order (heuristic, along_n, along_m)"}, }, { library::Provider::kCUBLAS} ) { @@ -119,7 +119,7 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const { << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" << " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" - + << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" << " $ cutlass_profiler --operation=Gemm \\ \n" << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" @@ -150,9 +150,9 @@ Status GemmOperationProfiler::GemmProblem::parse( library::GemmDescription const &operation_desc, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - + this->mode = library::GemmUniversalMode::kGemm; - + if (!arg_as_int(this->m, "m", problem_space, problem)) { // default value this->m = 1024; @@ -162,17 +162,17 @@ Status GemmOperationProfiler::GemmProblem::parse( // default value this->n = 1024; } - + if (!arg_as_int(this->k, "k", problem_space, problem)) { // default value this->k = 1024; } - + if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { // default value this->split_k_mode = library::SplitKMode::kSerial; } - + this->mode = library::GemmUniversalMode::kGemm; if (this->split_k_mode == library::SplitKMode::kParallel) { this->mode = library::GemmUniversalMode::kGemmSplitKParallel; @@ -182,7 +182,7 @@ Status GemmOperationProfiler::GemmProblem::parse( // default value this->split_k_slices = 1; } - + if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { // default value this->batch_count = 1; @@ -194,7 +194,7 @@ Status GemmOperationProfiler::GemmProblem::parse( // default value this->raster_order = library::RasterOrder::kHeuristic; } - + if (this->split_k_slices > 1 && this->batch_count > 1) { // At least one of these must be one return Status::kErrorInvalidProblem; @@ -217,24 +217,24 @@ Status GemmOperationProfiler::GemmProblem::parse( } if (!arg_as_scalar( - this->alpha, - operation_desc.element_epilogue, - "alpha", - problem_space, + this->alpha, + operation_desc.element_epilogue, + "alpha", + problem_space, problem)) { if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { return Status::kErrorInternal; } } - + if (!arg_as_scalar( - this->beta, - operation_desc.element_epilogue, - "beta", - problem_space, + this->beta, + operation_desc.element_epilogue, + "beta", + problem_space, problem)) { - + if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { return Status::kErrorInternal; } @@ -327,7 +327,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); - set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); set_argument(result, "alpha", problem_space, library::lexical_cast(alpha, operation_desc.element_epilogue)); @@ -339,14 +339,14 @@ void GemmOperationProfiler::GemmProblem::initialize_result( /// Extracts the problem dimensions Status GemmOperationProfiler::initialize_configuration( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - library::GemmDescription const &operation_desc = + library::GemmDescription const &operation_desc = static_cast(operation->description()); if (operation_desc.gemm_kind != library::GemmKind::kUniversal) { @@ -383,7 +383,6 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; gemm_workspace_.arguments.raster_order = problem_.raster_order; - // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { if (!initialize_reduction_configuration_(operation, problem)) { @@ -392,14 +391,14 @@ Status GemmOperationProfiler::initialize_configuration( } initialize_result_(this->model_result_, options, operation_desc, problem_space); - + return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); } /// Initializes the performance result void GemmOperationProfiler::initialize_result_( PerformanceResult &result, - Options const &options, + Options const &options, library::GemmDescription const &operation_desc, ProblemSpace const &problem_space) { @@ -451,7 +450,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_( ); auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key); - + if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) { return false; } @@ -465,7 +464,7 @@ bool GemmOperationProfiler::initialize_reduction_configuration_( /// Initializes workspace Status GemmOperationProfiler::initialize_workspace( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -480,14 +479,14 @@ Status GemmOperationProfiler::initialize_workspace( } } - library::GemmDescription const &operation_desc = + library::GemmDescription const &operation_desc = static_cast(operation->description()); // Compute the number of copies of the problem to avoid L2 camping. if (!options.profiling.workspace_count) { int64_t bytes = problem_.bytes(operation_desc); if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { - gemm_workspace_.problem_count = + gemm_workspace_.problem_count = 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); } else { @@ -629,7 +628,7 @@ Status GemmOperationProfiler::initialize_workspace( /// Verifies CUTLASS against references bool GemmOperationProfiler::verify_cutlass( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -685,7 +684,7 @@ bool GemmOperationProfiler::verify_cutlass( } results_.back().status = underlying_operation->run( - &gemm_workspace_.arguments, + &gemm_workspace_.arguments, gemm_workspace_.host_workspace.data(), gemm_workspace_.device_workspace.data()); @@ -748,8 +747,8 @@ bool GemmOperationProfiler::verify_cutlass( #endif // #if CUTLASS_ENABLE_CUBLAS bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem); - - // Update disposition to worst case verification outcome among all + + // Update disposition to worst case verification outcome among all // verification providers which are supported bool is_any_verification_run_passed = false; for (auto &m : results_.back().verification_map) { @@ -788,7 +787,7 @@ bool GemmOperationProfiler::verify_cutlass( /// Verifies CUTLASS against references bool GemmOperationProfiler::verify_with_cublas_( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -798,13 +797,13 @@ bool GemmOperationProfiler::verify_with_cublas_( #if CUTLASS_ENABLE_CUBLAS - library::GemmDescription const &gemm_desc = + library::GemmDescription const &gemm_desc = static_cast(operation->description()); // // Construct cuBLAS operators // - + CublasCreate handle; cublasStatus_t status = handle.get_cublas_create_status(); @@ -817,8 +816,8 @@ bool GemmOperationProfiler::verify_with_cublas_( std::vector algorithms; detail::select_cublas_algorithms( - algorithms, - options, + algorithms, + options, gemm_desc); if (algorithms.empty()) { @@ -849,8 +848,8 @@ bool GemmOperationProfiler::verify_with_cublas_( gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - detail::cublasGemmExDispatcher gemm_op( - gemm_desc, + detail::cublasGemmExDispatcher gemm_op( + gemm_desc, gemm_workspace_.configuration, gemm_workspace_.arguments, algorithms.front() @@ -884,7 +883,7 @@ bool GemmOperationProfiler::verify_with_cublas_( ); // Save workspace if incorrect - if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && results_.back().verification_map[library::Provider::kCUBLAS] == Disposition::kIncorrect) { save_workspace( @@ -909,14 +908,14 @@ bool GemmOperationProfiler::verify_with_cublas_( /// Verifies CUTLASS against host and device references bool GemmOperationProfiler::verify_with_reference_( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { - library::GemmDescription const &gemm_desc = + library::GemmDescription const &gemm_desc = static_cast(operation->description()); // @@ -1016,7 +1015,7 @@ bool GemmOperationProfiler::verify_with_reference_( results_.back().status = status; if (provider == library::Provider::kReferenceHost) { - gemm_workspace_.Reference->copy_from_host(ptr_D); + gemm_workspace_.Reference->copy_from_host(ptr_D); } // @@ -1031,7 +1030,7 @@ bool GemmOperationProfiler::verify_with_reference_( ); // Save workspace if incorrect - if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && results_.back().verification_map[provider] == Disposition::kIncorrect) { save_workspace( @@ -1050,7 +1049,7 @@ bool GemmOperationProfiler::verify_with_reference_( /// Measures performance results bool GemmOperationProfiler::profile( - Options const &options, + Options const &options, PerformanceReport &report, DeviceContext &device_context, library::Operation const *operation, @@ -1131,7 +1130,7 @@ Status GemmOperationProfiler::profile_cutlass_( Status status; for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { - + int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); @@ -1184,7 +1183,7 @@ Status GemmOperationProfiler::profile_cutlass_( int iteration = 0; for (; iteration < Iterations; ++iteration) { - + // Iterate over copies of the problem in memory int workspace_idx = options.profiling.warmup_iterations + iteration; int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 7592c81a..ae67c550 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -181,7 +181,7 @@ public: device_.reset(); host_.clear(); - count = count / kElementsPerStoredVec * kNumStoragePerStoredVec; + count = (count + kElementsPerStoredVec - 1) / kElementsPerStoredVec * kNumStoragePerStoredVec; host_.resize(count); // Allocate memory diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index b21582e0..13f3a5b4 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -45,6 +45,7 @@ namespace cutlass { // Strides without batch mode template +CUTLASS_HOST_DEVICE cute::Stride> make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, @@ -55,6 +56,7 @@ make_cute_packed_stride(cute::Stride> s, cute::Shape +CUTLASS_HOST_DEVICE cute::Stride, IntT> make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, @@ -69,6 +71,7 @@ make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape +CUTLASS_HOST_DEVICE cute::Stride, int64_t> make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, @@ -86,6 +89,7 @@ make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape } template +CUTLASS_HOST_DEVICE cute::Stride, IntT, int64_t> make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { static_assert(std::is_integral_v, diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 5e75d08b..70c8c0da 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -257,16 +257,19 @@ void gett_epilogue( using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; constexpr bool IsScalingAndAmaxOutputNeeded = - std::is_same_v or - std::is_same_v; + cute::is_same_v or + cute::is_same_v; constexpr bool IsScalingAndAmaxAuxOutputNeeded = - std::is_same_v or - std::is_same_v; + cute::is_same_v or + cute::is_same_v; constexpr bool IsReLUAuxNeeded = - cute::is_same_v> and + (cute::is_same_v> or + cute::is_same_v>) and cute::is_same_v; + constexpr bool IsClamp = + cute::is_same_v>; constexpr bool IsBackpropFusion = cute::is_same_v> or @@ -276,7 +279,7 @@ void gett_epilogue( NumericConverter accumulator_converter; NumericConverter source_converter; NumericConverter bias_converter; - NumericConverter aux_source_converter; + [[maybe_unused]] NumericConverter aux_source_converter; // Scale related converter NumericConverter scale_converter; @@ -369,7 +372,12 @@ void gett_epilogue( } } - output = activation(output); + if constexpr (IsClamp) { // Treat Clamp as ReLU + output = activation(output, {0, std::numeric_limits::max()}); + } + else { + output = activation(output); + } } if constexpr (IsScalingAndAmaxOutputNeeded) { @@ -436,14 +444,14 @@ void Gemm3x( static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { - Layout layout_A = make_layout_rank3(mainloop_params.A); - Layout layout_B = make_layout_rank3(mainloop_params.B); - Layout layout_C = make_layout_rank3(epilogue_params.C); - Layout layout_D = make_layout_rank3(epilogue_params.D); - Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); - Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); - Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); - Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); auto TensorB = make_tensor(mainloop_params.B.data(), layout_B);