diff --git a/CHANGELOG.md b/CHANGELOG.md index 95419bcb..63e3e80e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./cutlass/media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). + - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. * Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that @@ -32,6 +32,7 @@ * CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. - Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes. - Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors. +* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels. * Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture: - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API. - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. @@ -46,14 +47,15 @@ - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128. * Documentation updates: - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md) + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). + - Support grouped GEMM in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details). ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). - [Distributed GEMM](./examples/65_distributed_gemm/65_distributed_gemm.cu) is a new (experimental) API which can turn existing CUTLASS GEMM kernels into pipelined Tensor Parallel GEMMs that run efficiently on NVLink-based network of GPUs. Its pipelining schedules can hide most of the communication behind computation, and relies on point-to-point communication, which can simply use CUDA runtime's peer device access feature. It also utilizes remote TMA loads and memcopies with CUDA graphs to handle communication primarily through the Copy Engine, leaving all SMs free for Hopper's persistent kernels. For more details you can refer to the [DistGEMM blog post](https://blog.shi-labs.com/distributed-gemm-88be6a481e2b). -- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). +- Improved persistent grid launch for Hopper kernels with large cluster sizes (>= size of 4) using the new `make_kernel_hardware_info` API as shown in [example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu). - Enabled high precision accumulation for Hopper FP8 Sparse GEMM. - Potential API breaking changes: + Fix `cute::UniversalCopy` for type safety. diff --git a/CMakeLists.txt b/CMakeLists.txt index 9892f067..b9de4f96 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,6 +114,13 @@ set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) ################################################################################ + + +include(customConfigs.cmake) + +################################################################################ + + set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") if(CUTLASS_ENABLE_HEADERS_ONLY) @@ -395,12 +402,6 @@ endif() # ################################################################################################### -if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_BLACKWELL_TMA_SWIZZLE_ENABLED=1) - - list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUDA_ENABLE_PREFERRED_CLUSTER=1) -endif() - # Warnings-as-error exceptions and warning suppressions for Clang builds @@ -978,6 +979,94 @@ function(cutlass_add_executable_tests NAME TARGET) endfunction() + + +function(cutlass_generate_profiler_tests NAME) + + set(options) + set(oneValueArgs) + set(multiValueArgs DEPENDS DEPENDEES CUTLASS_PROFILER_EXTRA_OPTIONS) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (NOT CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS AND NOT CUTLASS_BUILD_FOR_PROFILER_PERFORMANCE_REGRESSIONS) + return() + endif() + + install( + FILES ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} + DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/ + RENAME profiler_regressions.csv + ) + + # Generate cmake test targets for each entry in the testlist csv + + if (NOT EXISTS "${CUTLASS_PROFILER_REGRESSION_LIST_FILE}") + message(SEND_ERROR "Profiler unit tests list path is invalid: CUTLASS_PROFILER_REGRESSION_LIST_FILE = ${CUTLASS_PROFILER_REGRESSION_LIST_FILE}") + else() + message(STATUS "Using ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} to generate profiler-based tests.") + endif() + + file(STRINGS ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} TEST_LIST) + + foreach(TEST IN LISTS TEST_LIST) + + if ("${TEST}" MATCHES " *cutlass_profiler.*") + + # Generate a flattened name for the test from the test command line. + string(REPLACE "," ";" TEST_NAME_LIST ${TEST}) + list(GET TEST_NAME_LIST 0 TEST) + string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST}") + list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass") + list(JOIN TEST_NAME "_" TEST_NAME) + string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}") + string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}") + + if (__DO_NOT_LOWERCASE_TEST_NAME) + string(TEST_NAME_LOWER "${TEST_NAME}") + else() + string(TOLOWER "${TEST_NAME}" TEST_NAME_LOWER) + endif() + + # Munge the test command + string(REPLACE "cutlass_profiler" "" TEST "${TEST}") + set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME_LOWER}") + set(TEST_COMMAND_${TEST_NAME_LOWER} "${TEST}") + list(APPEND TEST_COMMAND_VARS ${TEST_NAME_LOWER}) + + endif() + + endforeach() + + cutlass_add_executable_tests( + ${NAME} cutlass_profiler + DEPENDS ${__DEPENDS} + DEPENDEES ${__DEPENDEES} + TEST_COMMAND_OPTIONS ${TEST_COMMAND_VARS} + TEST_COMMAND_OPTIONS_PREFIX TEST_COMMAND_ + DISABLE_EXECUTABLE_INSTALL_RULE + # Uncomment the following line when alloc/dealloc tracking + # is fixed for all configurations. + # TEST_SETS_SUPPORTED tmem_alloc_tracking + ) + +endfunction() + + + if (CUTLASS_ENABLE_TOOLS) add_subdirectory(tools) if (CUTLASS_ENABLE_PROFILER) diff --git a/README.md b/README.md index e29a1f85..f9f23c08 100644 --- a/README.md +++ b/README.md @@ -87,11 +87,11 @@ For a background on Blackwell's new features, please consult the PTX documentati - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128. * Documentation updates: - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/narrow_and_mixed_precision_gemms.md) + - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). -Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits. +Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. **See the [CHANGELOG](CHANGELOG.md) for details of all past releases and updates.** @@ -162,7 +162,7 @@ We have tested the following environments. Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. -Note: CUTLASS 3.x builds are known to be broken on Windows platforms for all CUDA toolkits. +Note: CUTLASS 3.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. ## Hardware diff --git a/customConfigs.cmake b/customConfigs.cmake new file mode 100644 index 00000000..c86e15be --- /dev/null +++ b/customConfigs.cmake @@ -0,0 +1,92 @@ +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + + + +# Profiler based functional testing +set(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS OFF CACHE BOOL "Utilize profiler-based functional regressions") +set(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL ${CUTLASS_TEST_LEVEL} CACHE STRING "Profiler functional regression test level") + +find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) + +function(cutlass_generate_kernel_filter_and_testlists_files) + + set(options) + set(oneValueArgs TEST_SET_NAME) + set(multiValueArgs) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + execute_process( + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR} + ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py + --generator-target=${__TEST_SET_NAME} + --cuda-version=${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION} + --architectures=${CUTLASS_NVCC_ARCHS} + --kernels=\* + --disable-cutlass-package-imports + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + RESULT_VARIABLE cutlass_FILTER_GENERATION_RESULT + OUTPUT_VARIABLE cutlass_FILTER_GENERATION_OUTPUT + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log + ) + + if(NOT cutlass_FILTER_GENERATION_RESULT EQUAL 0) + message(FATAL_ERROR "Error generating kernel filters and testlists files. See ${CMAKE_CURRENT_BINARY_DIR}/library_filter_generation.log") + endif() +endfunction() + +if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS) + + set(PROFILER_ARCH_LIST 100a) + foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS) + if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST)) + message(FATAL_ERROR "Only SM100a compute capability is supported with profiler-based unit tests") + endif() + endforeach() + + if(CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 0) + + message(STATUS "Building for L0 profiler-based functional regressions") + cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l0) + set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set") + set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L0_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set") + + elseif (CUTLASS_PROFILER_REGRESSION_TEST_LEVEL EQUAL 1) + + message(STATUS "Building for L1 profiler-based functional regressions") + cutlass_generate_kernel_filter_and_testlists_files(TEST_SET_NAME kernel_testlist_l1) + set(KERNEL_FILTER_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm_kernel_filter.list CACHE STRING "Kernel set") + set(CUTLASS_PROFILER_REGRESSION_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/FK_functional_L1_testlist_SM${CUTLASS_NVCC_ARCHS}_cutlass3x_gemm.csv CACHE STRING "Regression set") + + endif() +endif() + + diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 97e9061e..3a35cd71 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -483,18 +483,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { - std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; - return 0; - } - if (props.major != 9 || props.minor != 0) { std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + + // diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu index f250e4b9..8f4b8758 100644 --- a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -566,17 +566,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu index bdb59bfd..6fdcc836 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -103,11 +103,10 @@ #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/mixed_dtype_utils.hpp" #include "helper.h" #include "mixed_dtype_utils.hpp" -#include "packed_scale.hpp" -#include "reorder_utils.hpp" using namespace cute; @@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t; using ValueShuffle = Layout, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7] int constexpr NumShuffleAtoms = 1; using MmaAtomShape = Layout>>; -using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); -using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); +using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); using ElementScale = MmaType; using ElementZero = ElementScale; @@ -438,14 +437,15 @@ void initialize(Options const& options) { auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); - auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref); - dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream); if (options.shuffle) { // Repeat the reorder layout atom to tile the whole tensor shape - layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); - reorder_tensor(block_B.get(), layout_B, layout_B_reordered); + layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + cutlass::reorder_tensor(block_B.get(), layout_B, layout_B_reordered); print("Quantized tensor layout: "); print(layout_B_reordered); @@ -613,17 +613,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu index 581ccf88..cc540803 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -107,11 +107,10 @@ #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/mixed_dtype_utils.hpp" #include "helper.h" #include "mixed_dtype_utils.hpp" -#include "packed_scale.hpp" -#include "reorder_utils.hpp" using namespace cute; @@ -144,8 +143,8 @@ using StrideB = cutlass::detail::TagToStrideB_t; // Define the CuTe layout for reoredered quantized tensor B // LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. // It specifies the reordering within a single warp's fragment -using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); -using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); +using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); using ElementScale = MmaType; using ElementZero = ElementScale; // only for verify @@ -349,10 +348,10 @@ void initialize(Options const& options) { initialize_tensor(block_A, seed + 2022); initialize_quant_tensor(block_B, seed + 2021); - unify_quant_encoding(block_B, block_B_modified); + cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); initialize_tensor(block_C, seed + 2020); initialize_scale(block_scale, options); - initialize_packed_scale(block_scale, block_scale_packed); + cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size()); initialize_zero(block_zero, options); auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); @@ -360,12 +359,13 @@ void initialize(Options const& options) { stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); - dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream); if (options.shuffle) { // Repeat the reorder layout atom to tile the whole tensor shape - layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); - reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered); + layout_B_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + cutlass::reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered); print("Quantized tensor layout: "); print(layout_B_reordered); @@ -518,17 +518,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu index 2629ee33..aa114e74 100644 --- a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -100,6 +100,7 @@ #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/mixed_dtype_utils.hpp" #include "helper.h" #include "mixed_dtype_utils.hpp" @@ -322,9 +323,10 @@ void initialize(MixedDtypeOptions const& options) { auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); - auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + auto layout_scale_zero = cute::make_layout(shape_scale_zero, stride_S_ref); - dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g, stream); } /// Populates a Gemm::Arguments structure from the given commandline options @@ -483,17 +485,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp index 724adabc..f3dd9058 100644 --- a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp +++ b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -60,8 +60,8 @@ struct MixedDtypeOptions { float alpha = 1.0f; float beta = 0.0f; - int iterations = 1000; - int warmup = 1000; + int iterations = 100; + int warmup = 10; int mode = 1; int m = 5120, n = 4096, k = 4096; int g = 128; @@ -228,22 +228,18 @@ bool initialize_scale( MixedDtypeOptions const& options, uint64_t seed = 2023) { - if (options.mode == MixedDtypeGemmMode::ConvertOnly) { - // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. - std::vector stage(block.size(), Element(1.0f)); - block.copy_from_host(stage.data()); - } - else { + // If no scales, initialize with 1 so we can use the same kernel to dequantize the data + float scope_max = 1.0f, scope_min = 1.0f; + if (options.mode != MixedDtypeGemmMode::ConvertOnly) { float elt_max_f = float(cutlass::platform::numeric_limits::max()); const float max_dequant_val = 4.f; const float min_dequant_val = 0.5f; - - float scope_max(max_dequant_val / elt_max_f); - float scope_min(min_dequant_val / elt_max_f); - - cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + scope_max = max_dequant_val / elt_max_f; + scope_min = min_dequant_val / elt_max_f; } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + return true; } @@ -253,139 +249,14 @@ bool initialize_zero( MixedDtypeOptions const& options, uint64_t seed = 2023) { + // If no bias, initialize with 0 so we can use the same kernel to dequantize the data + float scope_max = 0.0f, scope_min = 0.0f; if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { - cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); - } else { - // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. - std::vector stage(block.size(), Element(0.0f)); - block.copy_from_host(stage.data()); + scope_max = 2.0f; + scope_min = -2.0f; } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + return true; } - -/// Dequantize the weights for verification - -template -__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, - QuantizedElement const* q_buffer, - OperandLayout const operand_layout, - ElementScale const* scale_buffer, - ElementZero const* zero_buffer, - ScaleBroadCastLayout const broadcasted_scale_layout, - ThrLayout thr_layout) { - using namespace cute; - - // Represent the full tensors to gmem elements. - // These are expected to have shape [MN, K, L] - cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); - auto init_quantized_iterator = [&]() { - if constexpr (cute::sizeof_bits_v >= 8) { - return cute::make_gmem_ptr(q_buffer); - } else { - return cute::subbyte_iterator(q_buffer); - } - }; - cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout); - // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting - // It is expected that K % G == 0 - cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); - cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); - - // Assign 1 thread per element in the thread block - auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); // - auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) - - // Tile across the block - auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); - auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); - auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); - auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); - - auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); - auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); - auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); - auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); - - // Make a fragment of registers to hold gmem loads - cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); - cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); - cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); - cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); - cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); - cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); - - cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); - auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); - auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); - - const auto num_iters = cute::size<3>(tOpDq_gOpDq); - - for (int ii = 0; ii < num_iters; ++ii) { - const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); - if (thread_offset < cute::size<0>(operand_layout)) { - cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); - cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); - cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); - cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); - cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); - cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{}); - cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{}); - cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); - cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); - } - } -} - -template -void dequantize_weight(DequantizedElement* dq_buffer, - QuantizedElement const* q_buffer, - OperandLayout const operand_layout, - ElementScale const* scale_buffer, - ElementZero const* zero_buffer, - ScaleLayout const scale_layout, - int const group_size) { - - using namespace cute; - - constexpr int tpb = 128; - auto thr_layout = make_layout(make_shape(Int{})); - - const auto num_rows = get<0>(shape(operand_layout)); - const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] - const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] - const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] - - if (num_rows != size<0>(scale_layout)) { - std::cerr << "Invalid first dimension for scales. Must match first dim for weights." - << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) - << std::endl; - exit(-1); - } - - const auto scale_stride0 = get<0>(stride(scale_layout)); - const auto scale_stride1 = get<1>(stride(scale_layout)); - const auto scale_stride2 = get<2>(stride(scale_layout)); - - auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); - auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); - auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); - - const auto blocks_x = gemm_k; - const auto blocks_y = batches; - - dim3 blocks(blocks_x, blocks_y, 1); - dequantize_weight_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); - CUDA_CHECK(cudaDeviceSynchronize()); -} diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp deleted file mode 100644 index a595ca72..00000000 --- a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp +++ /dev/null @@ -1,211 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#pragma once - -#include - -#include "cutlass/util/device_memory.h" -#include "cutlass/integer_subbyte.h" -#include "cutlass/float8.h" -#include "cutlass/util/reference/device/tensor_fill.h" - -#include "cute/tensor.hpp" -#include "cute/util/type_traits.hpp" - -namespace cutlass -{ -template -class packed_scale_t { -public: - static_assert(cute::is_same_v || - cute::is_same_v || - cute::is_same_v || - cute::is_same_v, - "only 8 bit arithmetic types are supported."); - CUTLASS_HOST_DEVICE - explicit packed_scale_t(T val) { - if constexpr (!cute::is_unsigned_v) { - // Only pack negative values. The positive values are generated in flight in the mainloop. - storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); - storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); - } - else { - storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); - storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); - } - } - CUTLASS_HOST_DEVICE - packed_scale_t() = default; - CUTLASS_HOST_DEVICE - explicit operator float() const { - return float(get()); - } - CUTLASS_HOST_DEVICE - bool operator==(packed_scale_t const& rhs) const { - return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; - } - CUTLASS_HOST_DEVICE - bool operator!=(packed_scale_t const& rhs) const { - return !(*this == rhs); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() + rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() - rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() * rhs.get()); - } - CUTLASS_HOST_DEVICE - friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { - return packed_scale_t(lhs.get() / rhs.get()); - } - -private: - using Storage = uint32_t; - using Stage = uint8_t; - - Storage storage[2] {}; - - CUTLASS_HOST_DEVICE - static Storage pack4(T c1, T c2, T c3, T c4) { - Storage result = 0; - result |= (static_cast(reinterpret_cast(c4)) << 24); - result |= (static_cast(reinterpret_cast(c3)) << 16); - result |= (static_cast(reinterpret_cast(c2)) << 8); - result |= static_cast(reinterpret_cast(c1)); - return result; - } - CUTLASS_HOST_DEVICE - T get() const { - auto stage = static_cast(storage[0] >> 8); - #if defined(__CUDA_ARCH__) - return reinterpret_cast(stage); - #else - T tmp; - std::memcpy(&tmp, &stage, sizeof(Stage)); - return tmp; - #endif - } - CUTLASS_HOST_DEVICE - T get(int idx) const { - Stage stage; - if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); - else stage = static_cast(storage[1] >> (8 * idx - 32)); - #if defined(__CUDA_ARCH__) - return reinterpret_cast(stage); - #else - T tmp; - std::memcpy(&tmp, &stage, sizeof(Stage)); - return tmp; - #endif - } -}; -} - -/// Helpers to initialize scale lookup table - -// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. -// Here the encodings of positive values and negative values are unified (except for the sign bit). -// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). -bool unify_quant_encoding( - cutlass::DeviceAllocation const& block_in, - cutlass::DeviceAllocation& block_out) { - - using StorageType = cutlass::int4b_t::Storage; - - if (block_in.size() != block_out.size()) { - std::cerr << "block_in and block_out must have same size.\n"; - return false; - } - constexpr int pack = cute::sizeof_bits_v / 4; - std::vector data(block_in.size() / pack); - cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack); - - for (auto&& d : data) { - StorageType out = 0; - StorageType mask = 0x0f; - for (int i = 0; i < pack; ++i) { - cutlass::int4b_t curr; - curr.storage = (d >> (i * 4)) & 0x0f; - switch (curr) { - case 1: curr.storage = StorageType(0b0111); break; // 2's complement - case 2: curr.storage = StorageType(0b0110); break; // 2's complement - case 3: curr.storage = StorageType(0b0101); break; // 2's complement - case 4: curr.storage = StorageType(0b0100); break; // 2's complement - case 5: curr.storage = StorageType(0b0011); break; // 2's complement - case 6: curr.storage = StorageType(0b0010); break; // 2's complement - case 7: curr.storage = StorageType(0b0001); break; // 2's complement - default: break; - } - out |= (curr.storage << (4 * i)) & mask; - mask <<= 4; - } - d = out; - } - - cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack); - return true; -} - -template -bool initialize_packed_scale( - cutlass::DeviceAllocation const& block_in, - cutlass::DeviceAllocation > & block_out) { - - std::vector data_in(block_in.size()); - std::vector > data_out(block_in.size()); - try { - block_in.copy_to_host(data_in.data()); - } catch (cutlass::cuda_exception const& e) - { - std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; - return false; - } - for (size_t i = 0; i < block_in.size(); ++i) - { - cutlass::packed_scale_t tmp(data_in[i]); - data_out[i] = reinterpret_cast const&>(tmp); - } - try { - block_out.copy_from_host(data_out.data()); - } catch (cutlass::cuda_exception const& e) - { - std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; - return false; - } - return true; -} diff --git a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp deleted file mode 100644 index 0f4e38d6..00000000 --- a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp +++ /dev/null @@ -1,162 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/arch/mma_sm90.hpp" - -#include "cutlass/util/device_memory.h" - -// Given a type of MMA instruction, compute a memory reordering atom that places all values -// owned by each thread in contiguous memory locations. This improves smem load vectorization, -// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order -// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. -// In addition, we can reorder the values across several MMA instructions to get even wider -// vectorization (AtomLayout parameter) and permute the values within each instruction to get -// more optimal conversion instruction sequences (ValLayout parameter). -template, - class ValLayout = cute::Layout> -constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) -{ - using namespace cute; - - static_assert(is_static_v, "ValLayout must be static"); - static_assert(is_static_v, "AtomLayout must be static"); - - // 1. Choose an MMA atom to access TV layout and MN shape - // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary - using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); - using MmaTraits = MMA_Traits; - auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); - auto tv_layout_mma = typename MmaTraits::ALayout{}; - static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); - - // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) - // Note: this assumes A is partitioned between warps along M mode - auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); - auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); - auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); - auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); - - // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization - auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); - - // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) - auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); - auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); - auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); - auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); - - return layout_atom; -} - -template -__global__ void reorder_tensor_kernel( - cute::Tensor S, - cute::Tensor D, - TiledCopy tiled_copy) -{ - using namespace cute; - - using T = typename EngineDst::value_type; - - Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); - Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); - - auto thread_copy = tiled_copy.get_slice(threadIdx.x); - Tensor tS = thread_copy.partition_S(gS); - Tensor tD = thread_copy.partition_D(gD); - - copy(tiled_copy, tS, tD); -} - -template -void reorder_tensor( - cute::Tensor S, - cute::Tensor D) -{ - using namespace cute; - - using T = typename EngineDst::value_type; - static_assert(is_same_v, T>, "Type mismatch"); - - // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread - // This avoids a race condition when writing out subbyte types (e.g. int4b_t). - auto has_major_mode = [](auto s) { - return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; }); - }; - static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), - "Could not find stride-1 mode in destination layout"); - constexpr int N = shape_div(Int<8>{}, sizeof_bits{}); - auto val_layout = conditional_return(LayoutDst{}))>( - make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), - make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); - - // Make a tiled copy with a simple row-major thread order and above layout - int constexpr NumThreads = 128; - auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); - - // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper - using TileShape = Shape<_16>; - auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); - dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; - - reorder_tensor_kernel<<>>(S, D, tiled_copy); - CUDA_CHECK(cudaDeviceSynchronize()); -} - -// In-place version -template -void reorder_tensor( - T const* src, - LayoutSrc const& layout_src, - T * dst, - LayoutDst const& layout_dst) -{ - using namespace cute; - reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), - make_tensor(make_gmem_ptr(dst), layout_dst)); -} - -// In-place version -template -void reorder_tensor( - T * data, - LayoutSrc const& layout_src, - LayoutDst const& layout_dst) -{ - using namespace cute; - cutlass::DeviceAllocation temp(size(layout_src)); - reorder_tensor(data, layout_src, temp.get(), layout_dst); - cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); -} diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu index ec29bc05..4f77ae03 100644 --- a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -513,17 +513,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu index 0f014cc7..6cedb599 100644 --- a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -731,17 +731,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu index cdf94c01..d84934ac 100644 --- a/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu +++ b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu @@ -768,16 +768,26 @@ int main(int argc, char const** argv) { return -1; } - if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) || - (props.major != 8 && props.minor != 9)) { + bool satisfied; + if (props.major < 10) { + // Pre-Blackwell + satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4); + satisfied &= (props.major > 8) || (props.major == 8 && props.minor == 9); + } + else { + satisfied = (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8); + } + if (!satisfied) { // - // This example requires an NVIDIA Ada-architecture GPU. + // This example requires an NVIDIA GPU with compute capability 8.9 or greater. // std::cout - << "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture " - << "and CUDA toolkit version 12.4 or later.\n"; + << "CUTLASS's FP8 SM89 example requires an NVIDIA GPU with compute capability 8.9 or greater " + << "and CUDA toolkit version 12.4 or later" + << " (12.8 or later needed for SM100+)" + << std::endl; return 0; } diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu index a5f59ca2..62da02c0 100644 --- a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu +++ b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu @@ -504,17 +504,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } // diff --git a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu index 708d1db6..da057e2d 100644 --- a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu +++ b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu @@ -570,18 +570,13 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { - std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; - return 0; - } - if (props.major != 9 || props.minor != 0) { std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } + + // diff --git a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu index 03b54f3e..9fcb9dee 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu +++ b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu @@ -469,18 +469,12 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major < 9) { + if (props.major != 9 || props.minor != 0) { std::cerr - << "This example requires a GPU of NVIDIA's Hopper Architecture or " - << "later (compute capability 90 or greater).\n"; + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; return 0; } - - else if (props.major != 9 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; - return 0; - } - + // // Parse options diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index d6de7f89..0c407d34 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -31,28 +31,20 @@ /*! \file \brief Grouped scale Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture - This example demonstrate a grouped scaled FP8 GEMM using the new CUTLASS 3.0. APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: - 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) which are more efficient than the Ampere tensor core instructions. - 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous copies between thread blocks in a cluster. - 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). - 4. This example shows all important fusions used by FP8 gemm kernels, i.e., grouped scale factor along M for A, blocked scale factor along K for A tensor, blocked scale factor for B tensor, the abs_max value of D tensor. - 5. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can improve performance. - Examples: - $ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \ --m=2816 --n=3072 --k=16384 \ --save_aux=false --save_amax=false \ diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt index 3453ed40..b22b281f 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt @@ -34,4 +34,4 @@ cutlass_example_add_executable( cutlass_example_add_executable( 67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling 67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu - ) \ No newline at end of file + ) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h index 652f72c1..cb3ff022 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -191,7 +191,7 @@ void gett_mainloop( static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - + using cute::raw_pointer_cast; using ElementA = typename ElementTraits::type; diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu new file mode 100644 index 00000000..b22d8305 --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief + NOTE: Write docu +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "helper.h" +#include "grouped_mixed_dtype_utils.hpp" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using MmaType = cutlass::bfloat16_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +// Need to pass a pointer type to make the 3rd dimension of Stride be _0 +using StrideA = cute::remove_pointer_t>; +using StrideB = cute::remove_pointer_t>; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. +// It specifies the reordering within a single warp's fragment +// using ValueShuffle = Layout<_1>; // no value reordering +using ValueShuffle = Layout, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7] +int constexpr NumShuffleAtoms = 1; +using MmaAtomShape = Layout>>; +using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout>, StrideB>{})); + +using ElementZero = cutlass::bfloat16_t; +using ElementScale = cutlass::bfloat16_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_16,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, typename cutlass::layout::LayoutTranspose::type *, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Transpose *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopConvertOnly, + CollectiveEpilogue +>; + +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Reordered *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopConvertOnlyShuffled, + CollectiveEpilogue +>; + +using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Reordered *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnlyShuffled, + CollectiveEpilogue +>; + +using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideC = typename GemmKernelConvertOnly::InternalStrideC; +using StrideD = typename GemmKernelConvertOnly::InternalStrideD; +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_B_dq; +std::vector offset_C; +std::vector offset_D; +std::vector offset_scale; +std::vector offset_zero; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector stride_C_host_ref; +std::vector stride_D_host_ref; +std::vector stride_S_host; +std::vector stride_S_host_ref; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed = 2020; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_B_dq; +cutlass::DeviceAllocation ptr_scale; +cutlass::DeviceAllocation ptr_zero; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_B_reordered; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation stride_C_ref; +cutlass::DeviceAllocation stride_D_ref; +cutlass::DeviceAllocation stride_S_ref; +cutlass::DeviceAllocation stride_S; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options : GroupedMixedDtypeOptions { + using Base = GroupedMixedDtypeOptions; + + bool shuffle = true; + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("shuffle", shuffle); + + this->Base::parse(argc, args); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "69_hopper_int4_bf16_grouped_gemm\n\n" + << " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --warmup= Number of warmup iterations to perform\n\n" + << " --shuffle= Enable the offline layout swizzling.\n\n" + << " --benchmark= Executes a benchmark problem size.\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "69_hopper_int4_bf16_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n"; + + return out; + } +}; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Allocates device-side data +void allocate(Options const& options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_B_dq = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_scale = 0; + int64_t total_elements_zero = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + const int scale_k = 1; + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); + offset_B_dq.push_back(total_elements_B_dq); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_scale.push_back(total_elements_scale); + offset_zero.push_back(total_elements_zero); + + int64_t elements_A = M * K; + int64_t elements_B = K * N ; + int64_t elements_B_dq = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_scale = scale_k * N; + int64_t elements_zero = scale_k * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_B_dq += elements_B_dq; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_scale += elements_scale; + total_elements_zero += elements_zero; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1})); + stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1})); + stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1})); + stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1})); + stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1})); + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_B_dq.reset(total_elements_B_dq); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_scale.reset(total_elements_scale); + block_zero.reset(total_elements_zero); + + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_B_dq_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_scale_host(options.groups); + std::vector ptr_zero_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i); + ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_B_dq.reset(options.groups); + ptr_B_dq.copy_from_host(ptr_B_dq_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_scale.reset(options.groups); + ptr_scale.copy_from_host(ptr_scale_host.data()); + + ptr_zero.reset(options.groups); + ptr_zero.copy_from_host(ptr_zero_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + stride_C_ref.reset(options.groups); + stride_C_ref.copy_from_host(stride_C_host_ref.data()); + + stride_D_ref.reset(options.groups); + stride_D_ref.copy_from_host(stride_D_host_ref.data()); + + stride_S_ref.reset(options.groups); + stride_S_ref.copy_from_host(stride_S_host_ref.data()); + + stride_S.reset(options.groups); + stride_S.copy_from_host(stride_S_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_tensor(block_A, seed + 2023); + initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_C, seed + 2021); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + + for (int32_t i = 0; i < options.groups; ++i) { + const int scale_k = 1; + auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{}); + auto shape_scale = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), scale_k, Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B_host.at(i)); + auto layout_scale = make_layout(shape_scale, stride_S_host_ref.at(i)); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale, options.k, stream); + } + + problem_sizes.reset(options.groups); + + if (options.shuffle) { + std::vector layout_B_reordered_host(options.groups); + for (int32_t i = 0; i < options.groups; ++i) { + auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B_host.at(i)); + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B); + cutlass::reorder_tensor(block_B.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]); + if (i == 0) { + print("Quantized tensor layout: "); + print(layout_B_reordered_host[0]); + print("\n"); + } + } + layout_B_reordered.reset(options.groups); + layout_B_reordered.copy_from_host(layout_B_reordered_host.data()); + } + + // Reverse MN -> NM for SwapAB + for (int32_t i = 0; i < options.groups; ++i) { + auto [M, N, K] = options.problem_sizes_host[i]; + options.problem_sizes_host[i] = make_tuple(N, M, K); + } + problem_sizes.copy_from_host(options.problem_sizes_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true) +{ + using Args = typename Gemm::Arguments; + auto&& dB = [&]() { + // NOTE: add GemmScaleWithZeroPointShuffled + if constexpr (cute::is_same_v || + cute::is_same_v) { + // offline swizzling is enabled. + return layout_B_reordered.get(); + } + else { + return stride_B.get(); + } + }(); + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + Args arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) { + arguments = Args { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) { + arguments = Args { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else { + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + exit(-1); + } + return arguments; +} + +bool verify(Options const& options) { + bool passed = true; + + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA_verif = typename GemmRef::GemmKernel::StrideA; + using StrideB_verif = typename GemmRef::GemmKernel::StrideB; + using StrideC_verif = typename GemmRef::GemmKernel::StrideC; + using StrideD_verif = typename GemmRef::GemmKernel::StrideD; + + const ElementD epsilon(1e-2f); + const ElementD non_zero_floor(1e-4f); + + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto N = get<0>(problem); + auto M = get<1>(problem); + auto K = get<2>(problem); + if (M == 0) { + continue; + } + else { + StrideA_verif stride_A_verif; + StrideB_verif stride_B_verif; + + stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); + stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); + + // + // Compute reference output + // + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif}, + {{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); + std::cout << "Group: " << i << " Status: " << passed << std::endl; + } + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + std::cout << "We passed all checks\n"; + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host); + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 9 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + std::cout << "Running in no scale mode." << std::endl; + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options, false); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options, false); + } + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + std::cout << "Running in per-column scale mode." << std::endl; + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options, false); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options, false); + } + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu new file mode 100644 index 00000000..cc0494ec --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu @@ -0,0 +1,753 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief + NOTE: Write docu +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "helper.h" +#include "grouped_mixed_dtype_utils.hpp" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +// Need to pass a pointer type to make the 3rd dimension of Stride be _0 +using StrideA = cute::remove_pointer_t>; +using StrideB = cute::remove_pointer_t>; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. +// It specifies the reordering within a single warp's fragment +using LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape(LayoutAtomQuant{}, Layout>, StrideB>{})); + +using ElementZero = cutlass::float_e4m3_t; +using ElementScale = cutlass::float_e4m3_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_16,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, typename cutlass::layout::LayoutTranspose::type *, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutB_Transpose *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutB_Reordered *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopShuffled, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; +using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideC = typename GemmKernelScaleOnly::InternalStrideC; +using StrideD = typename GemmKernelScaleOnly::InternalStrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_B_dq; +std::vector offset_C; +std::vector offset_D; +std::vector offset_scale; +std::vector offset_zero; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector stride_C_host_ref; +std::vector stride_D_host_ref; +std::vector stride_S_host; +std::vector stride_S_host_ref; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed = 2020; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_modified; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation> block_scale_packed; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_B_dq; +cutlass::DeviceAllocation *> ptr_scale_packed; +cutlass::DeviceAllocation ptr_zero; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_B_reordered; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation stride_C_ref; +cutlass::DeviceAllocation stride_D_ref; +cutlass::DeviceAllocation stride_S_ref; +cutlass::DeviceAllocation stride_S; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options : GroupedMixedDtypeOptions { + using Base = GroupedMixedDtypeOptions; + + bool shuffle = true; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("shuffle", shuffle); + + this->Base::parse(argc, args); + + mode = 1; // override the mode value to always be scale only mode + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "69_hopper_int4_fp8_grouped_gemm\n\n" + << " Hopper Mixed Dtype Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --c= The size of each chunk for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --warmup= Number of warmup iterations to perform\n\n" + << " --shuffle= Enable the offline layout swizzling.\n\n" + << " --benchmark= Executes a benchmark problem size.\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "69_hopper_int4_fp8_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=1 --beta=0 \n\n"; + + return out; + } +}; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). + +/// Allocates device-side data +void allocate(Options const& options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_B_dq = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_scale = 0; + int64_t total_elements_zero = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + const int scale_k = 1; + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); + offset_B_dq.push_back(total_elements_B_dq); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_scale.push_back(total_elements_scale); + offset_zero.push_back(total_elements_zero); + + int64_t elements_A = M * K; + int64_t elements_B = K * N ; + int64_t elements_B_dq = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_scale = scale_k * N; + int64_t elements_zero = scale_k * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_B_dq += elements_B_dq; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_scale += elements_scale; + total_elements_zero += elements_zero; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1})); + stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1})); + stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1})); + stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1})); + stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1})); + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_B_modified.reset(total_elements_B); + block_B_dq.reset(total_elements_B_dq); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_scale.reset(total_elements_scale); + block_scale_packed.reset(total_elements_scale); + block_zero.reset(total_elements_zero); + + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options& options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_B_dq_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector *> ptr_scale_packed_host(options.groups); + std::vector ptr_zero_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B_modified.get() + offset_B.at(i); + ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_scale_packed_host.at(i) = block_scale_packed.get() + offset_scale.at(i); + ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_B_dq.reset(options.groups); + ptr_B_dq.copy_from_host(ptr_B_dq_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_scale_packed.reset(options.groups); + ptr_scale_packed.copy_from_host(ptr_scale_packed_host.data()); + + ptr_zero.reset(options.groups); + ptr_zero.copy_from_host(ptr_zero_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + stride_C_ref.reset(options.groups); + stride_C_ref.copy_from_host(stride_C_host_ref.data()); + + stride_D_ref.reset(options.groups); + stride_D_ref.copy_from_host(stride_D_host_ref.data()); + + stride_S_ref.reset(options.groups); + stride_S_ref.copy_from_host(stride_S_host_ref.data()); + + stride_S.reset(options.groups); + stride_S.copy_from_host(stride_S_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_tensor(block_A, seed + 2023); + initialize_quant_tensor(block_B, seed + 2022); + cutlass::unified_encode_int4b(block_B.get(), block_B_modified.get(), block_B.size()); + initialize_tensor(block_C, seed + 2021); + initialize_scale(block_scale, options); + cutlass::pack_scale_fp8(block_scale.get(), block_scale_packed.get(), block_scale.size()); + initialize_zero(block_zero, options); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + problem_sizes.reset(options.groups); + + if (options.shuffle) { + std::vector layout_B_reordered_host(options.groups); + for (int32_t i = 0; i < options.groups; ++i) { + auto shape_B = cute::make_shape(cute::get<1>(options.problem_sizes_host[i]), cute::get<2>(options.problem_sizes_host[i]), Int<1>{}); + auto layout_B = make_layout(shape_B, stride_B_host.at(i)); + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered_host[i] = tile_to_shape(LayoutAtomQuant{}, shape_B); + cutlass::reorder_tensor(block_B_modified.get() + offset_B.at(i), layout_B, layout_B_reordered_host[i]); + if (i == 0) { + print("Quantized tensor layout: "); + print(layout_B_reordered_host[0]); + print("\n"); + } + } + layout_B_reordered.reset(options.groups); + layout_B_reordered.copy_from_host(layout_B_reordered_host.data()); + } + + // Reverse MN -> NM for SwapAB + for (int32_t i = 0; i < options.groups; ++i) { + auto [M, N, K] = options.problem_sizes_host[i]; + options.problem_sizes_host[i] = make_tuple(N, M, K); + } + problem_sizes.copy_from_host(options.problem_sizes_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true) +{ + using Args = typename Gemm::Arguments; + auto&& dB = [&]() { + if constexpr (cute::is_same_v) { // offline swizzling is enabled. + return layout_B_reordered.get(); + } + else { + return stride_B.get(); + } + }(); + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + Args arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + arguments = Args { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_B.get(), dB, ptr_A.get(), stride_A.get(), ptr_scale_packed.get(), stride_S.get(), options.k}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + return arguments; +} + + +bool verify(Options const& options) { + bool passed = true; + + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA_verif = typename GemmRef::GemmKernel::StrideA; + using StrideB_verif = typename GemmRef::GemmKernel::StrideB; + using StrideC_verif = typename GemmRef::GemmKernel::StrideC; + using StrideD_verif = typename GemmRef::GemmKernel::StrideD; + + const ElementD epsilon(1e-2f); + const ElementD non_zero_floor(1e-4f); + + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto N = get<0>(problem); + auto M = get<1>(problem); + auto K = get<2>(problem); + if (M == 0) { + continue; + } + else { + StrideA_verif stride_A_verif; + StrideB_verif stride_B_verif; + + stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); + stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); + + const int scale_k = 1; + auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); + auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + + // + // Compute reference output + // + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif}, + {{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); + std::cout << "Group: " << i << " Status: " << passed << std::endl; + } + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + std::cout << "We passed all checks\n"; + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host); + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 9 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + std::cout << "Running in per-column scale mode." << std::endl; + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options, false); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options, false); + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu new file mode 100644 index 00000000..883d8cbf --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu @@ -0,0 +1,678 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + NOTE: Write docu +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "helper.h" +#include "grouped_mixed_dtype_utils.hpp" + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using MmaType = cutlass::bfloat16_t; +using QuantType = cutlass::float_e5m2_t; +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementZero = cutlass::bfloat16_t; +using ElementScale = cutlass::bfloat16_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_16,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; +using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; // Epilogue to launch + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, typename cutlass::layout::LayoutTranspose::type *, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type *, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Transpose *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopConvertOnly, + CollectiveEpilogue +>; + +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose *, AlignmentB, + ElementA, LayoutA_Transpose *, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename GemmConvertOnly::GemmKernel::InternalStrideA; +using StrideB = typename GemmConvertOnly::GemmKernel::InternalStrideB; +using StrideC = typename GemmConvertOnly::GemmKernel::InternalStrideC; +using StrideD = typename GemmConvertOnly::GemmKernel::InternalStrideD; +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_B_dq; +std::vector offset_C; +std::vector offset_D; +std::vector offset_scale; +std::vector offset_zero; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector stride_C_host_ref; +std::vector stride_D_host_ref; +std::vector stride_S_host; +std::vector stride_S_host_ref; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed = 2020; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_B_dq; +cutlass::DeviceAllocation ptr_scale; +cutlass::DeviceAllocation ptr_zero; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation stride_C_ref; +cutlass::DeviceAllocation stride_D_ref; +cutlass::DeviceAllocation stride_S_ref; +cutlass::DeviceAllocation stride_S; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using Options = GroupedMixedDtypeOptions; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Allocates device-side data +void allocate(Options const& options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_B_dq = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_scale = 0; + int64_t total_elements_zero = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + const int scale_k = 1; + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B * cutlass::sizeof_bits::value / 8); + offset_B_dq.push_back(total_elements_B_dq); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_scale.push_back(total_elements_scale); + offset_zero.push_back(total_elements_zero); + + int64_t elements_A = M * K; + int64_t elements_B = K * N ; + int64_t elements_B_dq = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_scale = scale_k * N; + int64_t elements_zero = scale_k * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_B_dq += elements_B_dq; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_scale += elements_scale; + total_elements_zero += elements_zero; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {N, M, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {N, M, 1})); + stride_C_host_ref.push_back(cutlass::make_cute_packed_stride(StrideC_ref{}, {M, N, 1})); + stride_D_host_ref.push_back(cutlass::make_cute_packed_stride(StrideD_ref{}, {M, N, 1})); + stride_S_host_ref.push_back(cutlass::make_cute_packed_stride(StrideS_ref{}, {N, scale_k, 1})); + stride_S_host.push_back(cutlass::make_cute_packed_stride(StrideS{}, {N, scale_k, 1})); + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_B_dq.reset(total_elements_B_dq); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_scale.reset(total_elements_scale); + block_zero.reset(total_elements_zero); + + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_B_dq_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_scale_host(options.groups); + std::vector ptr_zero_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B_dq.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_scale_host.at(i) = block_scale.get() + offset_scale.at(i); + ptr_zero_host.at(i) = block_zero.get() + offset_zero.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_B_dq.reset(options.groups); + ptr_B_dq.copy_from_host(ptr_B_dq_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_scale.reset(options.groups); + ptr_scale.copy_from_host(ptr_scale_host.data()); + + ptr_zero.reset(options.groups); + ptr_zero.copy_from_host(ptr_zero_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + stride_C_ref.reset(options.groups); + stride_C_ref.copy_from_host(stride_C_host_ref.data()); + + stride_D_ref.reset(options.groups); + stride_D_ref.copy_from_host(stride_D_host_ref.data()); + + stride_S_ref.reset(options.groups); + stride_S_ref.copy_from_host(stride_S_host_ref.data()); + + stride_S.reset(options.groups); + stride_S.copy_from_host(stride_S_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_tensor(block_A, seed + 2023); + initialize_quant_tensor(block_B, seed + 2022); + initialize_tensor(block_C, seed + 2021); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + problem_sizes.reset(options.groups); + // Reverse MN -> NM for SwapAB + for (int32_t i = 0; i < options.groups; ++i) { + auto [M, N, K] = options.problem_sizes_host[i]; + options.problem_sizes_host[i] = make_tuple(N, M, K); + } + problem_sizes.copy_from_host(options.problem_sizes_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options const& options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::DirectConvert) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else if constexpr (Gemm::CollectiveMainloop::KernelConversionMode == Gemm::CollectiveMainloop::ConversionMode::ConvertAndScale) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_B.get(), stride_B.get(), ptr_A.get(), stride_A.get(), ptr_scale.get(), stride_S.get(), options.k}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else { + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + exit(-1); + } + return arguments; +} + +bool verify(Options const& options) { + bool passed = true; + + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA_verif = typename GemmRef::GemmKernel::StrideA; + using StrideB_verif = typename GemmRef::GemmKernel::StrideB; + using StrideC_verif = typename GemmRef::GemmKernel::StrideC; + using StrideD_verif = typename GemmRef::GemmKernel::StrideD; + + const ElementD epsilon(1e-2f); + const ElementD non_zero_floor(1e-4f); + + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto N = get<0>(problem); + auto M = get<1>(problem); + auto K = get<2>(problem); + if (M == 0) { + continue; + } + else { + StrideA_verif stride_A_verif; + StrideB_verif stride_B_verif; + + stride_A_verif = cutlass::make_cute_packed_stride(StrideA_verif{}, cute::make_shape(M, K, 1)); + stride_B_verif = cutlass::make_cute_packed_stride(StrideB_verif{}, cute::make_shape(N, K, 1)); + + const int scale_k = 1; + auto layout_B = make_layout(cute::make_shape(N, K, Int<1>{}), stride_B_host.at(i)); + auto layout_scale_zero = make_layout(cute::make_shape(N, scale_k, Int<1>{}), stride_S_host_ref.at(i)); + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get() + offset_B_dq.at(i), block_B.get() + offset_B.at(i), layout_B, block_scale.get() + offset_scale.at(i), block_zero.get() + offset_zero.at(i), layout_scale_zero, options.k, stream); + + // + // Compute reference output + // + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + {block_A.get() + offset_A.at(i), stride_A_verif, block_B_dq.get() + offset_B_dq.at(i), stride_B_verif}, + {{alpha_host.at(i), beta_host.at(i)}, block_C.get() + offset_C.at(i), stride_C_host_ref.at(i), block_ref_D.get() + offset_D.at(i), stride_D_host_ref.at(i)} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + passed &= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N, epsilon, non_zero_floor); + std::cout << "Group: " << i << " Status: " << passed << std::endl; + } + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + std::cout << "We passed all checks\n"; + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + grouped_mixed_dtype_profiling(gemm, options, result, alpha_host, beta_host); + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 9 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + std::cout << "Running in no scale mode." << std::endl; + run(options, false); + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + std::cout << "Running in group scale mode." << std::endl; + run(options, false); + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000..4c21cd48 --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/CMakeLists.txt @@ -0,0 +1,112 @@ +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=100 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=2.0 --beta=2.0 --groups=100 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=0.25 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=16 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=100 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=100 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=100 --iterations=10) # Random problem sizes + +set(TEST_DIRECT_BATCHED --m=2048 --n=5120 --k=8192 --mode=0 --iterations=0) # Direct conversion + +set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --c=8192 --mode=1 --iterations=0) # Per Column scaling + +cutlass_example_add_executable( + 69_hopper_mixed_dtype_grouped_gemm + 69_hopper_mixed_dtype_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL +) + +cutlass_example_add_executable( + 69_hopper_int4_fp8_grouped_gemm + 69_hopper_int4_fp8_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL +) + +cutlass_example_add_executable( + 69_hopper_int4_bf16_grouped_gemm + 69_hopper_int4_bf16_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL +) diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp new file mode 100644 index 00000000..db391cce --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp" + +template +class GroupedMixedDtypeOptions : public MixedDtypeOptions { +public: + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape; + + int groups = 6; + int c = 512; + std::string benchmark_path; + std::vector problem_sizes_host; + + GroupedMixedDtypeOptions() : MixedDtypeOptions() + { + m = 1024; + n = 2048; + k = 512; + }; + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("c", c); + MixedDtypeOptions::parse(argc, args); + + problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems(); + } + + std::ostream& print_usage(std::ostream& out) const { + out << "69_hopper_mixed_dtype_grouped_gemm\n\n" + << "Options:\n" + << " --help Display this usage statement\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems\n" + << " --mode= The mode to run the gemm\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations\n" + << " --warmup= Number of warmup iterations\n" + << " --benchmark= Executes a benchmark problem size\n"; + return out; + } + + double gflops(double runtime_s) const { + uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL, + [](uint64_t sum, const UnderlyingProblemShape& problem) { + return sum + static_cast(cute::get<0>(problem)) * + static_cast(cute::get<1>(problem)) * + static_cast(cute::get<2>(problem)); + }); + return (2.0 * fmas) / (runtime_s * 1e9); + } + +private: + static constexpr int tma_alignment_bits = 128; + const int alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + std::vector randomize_problems(cutlass::CommandLine& cmd) { + std::vector problems; + problems.reserve(groups); + + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + for (int i = 0; i < groups; ++i) { + int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1); + int n = (cmd_line_n >= 0) ? cmd_line_n : this->n; + int k = (cmd_line_k >= 0) ? cmd_line_k : this->k; + + if (k % alignment != 0) { + throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment)); + } + problems.push_back({m, n, k}); + } + return problems; + } + + std::vector load_benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file) { + throw std::runtime_error("Failed to open benchmark file: " + benchmark_path); + } + + std::vector problems; + int idx; + std::string extent_str; + + while (file >> idx >> extent_str) { + if (idx < 0 || extent_str.empty()) break; + + std::vector tokens; + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + cutlass::gemm::GemmCoord extent; + for (int i = 0; i < std::min(3, static_cast(tokens.size())); ++i) { + int x = std::stoi(tokens[i]); + extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x; + } + + if (extent.product()) { + problems.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problems.size()); + return problems; + } +}; + +template +void grouped_mixed_dtype_profiling( + Gemm& gemm, + const GroupedMixedDtypeOptions& options, + MixedDtypeResult& result, + const std::vector& alpha_host, + const std::vector& beta_host) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Sizes, Alpha, Beta\n"; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host[i] << ", " << alpha_host[i] << ", " << beta_host[i] << '\n'; + } + std::cout << " Groups : " << options.groups << '\n' + << " Avg runtime : " << result.avg_runtime_ms << " ms\n" + << " GFLOPS : " << result.gflops << '\n'; +} diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index 1486a3c6..fa65e508 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -124,13 +124,14 @@ constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // A using ElementAccumulator = float; // Element type for internal accumulation // using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands +using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands constexpr int OutputSFVectorSize = 16; using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< cutlass::epilogue::thread::SiLu, OutputSFVectorSize, ElementD, ElementAccumulator, - ElementSF, + ElementSFD, LayoutC, ElementC>; diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu index 9ac6f589..220f5fa8 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha_gen.cu @@ -466,7 +466,6 @@ struct ExampleRunner { int max_seqlen_kv = 0; for (auto e : seqlen_kv) { - // if (options.varlen) std::cout << "seqlen " << e << std::endl; max_seqlen_kv = std::max(e, max_seqlen_kv); } diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 4c9e784a..c840d8ba 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -29,11 +29,11 @@ set_property( SOURCE 77_blackwell_fmha.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v") + PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") set_property( SOURCE 77_blackwell_fmha_gen.cu - PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0 --ptxas-options -v") + PROPERTY COMPILE_FLAGS "--use_fast_math -ftemplate-backtrace-limit=0") set(TEST_BASIC --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=no) set(TEST_CAUSAL --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=causal) diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 3f063f56..60b411a3 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -529,7 +529,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); -// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _); // Each thread owns a single row using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem @@ -822,9 +821,7 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); -// tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); -// tTMEM_LOADsO_i.data() = tTMEM_LOADsO_i.data().get() + sO.layout()(_0{}, i * kCorrectionTileSize, _0{}); Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); @@ -939,8 +936,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { cute::mul(out, scale_f32x2, in); tTMrO_i(j) = out.x; tTMrO_i(j+1) = out.y; - //tTMrO(j) = scale * tTMrO(j); - //tTMrO(j+1) = scale * tTMrO(j+1); } copy_out(i); diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index 38b26196..655c080e 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -538,7 +538,6 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); -// Tensor tScS_P = tScS.compose(make_layout(make_shape(make_shape(_128{}, _32{}), _4{}, _1{}, _1{})))(_, _1{}, _, _); // Each thread owns a single row using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem @@ -956,8 +955,6 @@ struct Sm100FmhaGenMainloopWarpspecialized { cute::mul(out, scale_f32x2, in); tTMrO_i(j) = out.x; tTMrO_i(j+1) = out.y; - //tTMrO(j) = scale * tTMrO(j); - //tTMrO(j+1) = scale * tTMrO(j+1); } copy_out(i); diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp index c201d4f0..3f37d725 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -188,21 +188,15 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { // Q1 int q0_index = get<0>(blk_coord); -// pipeline_q.producer_acquire(pipeline_q_producer_state); - // copy_with_limit(tiled_copy_q, tQcQ, limitQ, tQgQ, tQsQ(_, _, _, _, pipeline_q_producer_state.index()); auto load_q = [&](int q_index, auto& state) { pipeline_q.producer_acquire(state); -// using Vec = Element; -// auto vzero = Element(0); // q is always loaded masked using Vec = uint128_t; Vec vzero = uint128_t(0, 0); - //auto src = recast(tQgQ(_, _, _, _, q_index)); auto src = recast(tQgQ(_, _, _, _)); auto dst = recast(tQsQ(_, _, _, _, state.index())); - // auto c = tQcQ(_, _, _, _, q_index); auto c = tQcQ(_, _, _, _); int vlen = sizeof(Vec) / sizeof(Element); CUTLASS_PRAGMA_UNROLL @@ -220,7 +214,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { }; load_q(q0_index, pipeline_q_producer_state); -// pipeline_q.producer_commit(pipeline_q_producer_state, cutlass::arch::cpasync_barrier_arrive); ++pipeline_q_producer_state; auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{})); @@ -287,8 +280,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); } else { -// using Vec = Element; -// auto vzero = Element(0); using Vec = uint128_t; Vec vzero = uint128_t(0, 0); auto src = recast(tKgK(_, _, _, _, k_index)); @@ -322,8 +313,6 @@ struct Sm100FmhaLoadCpAsyncWarpspecialized { copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); } else { -// using Vec = Element; -// auto vzero = Element(0); using Vec = uint128_t; Vec vzero = uint128_t(0, 0); auto src = recast(tVgV(_, _, _, _, v_index)); diff --git a/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp b/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp index 003fff65..40239c56 100644 --- a/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp +++ b/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp @@ -149,8 +149,6 @@ void __global__ fmha_fwd_gen_reference_kernel( __syncthreads(); for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) { - -// printf("O[%d,%d,%d] = %f\n", idx_d, idx_h, idx_b, mS[idx_d]); mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast(mS[idx_d]); } } diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu new file mode 100644 index 00000000..48e1da6c --- /dev/null +++ b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu @@ -0,0 +1,475 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A Blackwell CUTLASS GEMM example for FastFP32 (using BF16 to emulate SGEMM). + + This example demonstrates how to run an emulated SGEMM with BF16x9 on an NVIDIA GPU that supports + NVIDIA's Blackwell architecture (SM100a). Using BF16x9 leverages tensor cores, providing much + greater throughput compared to SIMT instructions. + + To emulate SGEMM using BF16x9, the A and B matrices are decomposed to three lower precision elements: + a = a1 + a2 + a3 + b = b1 + b2 + b3 + + One FP32 MAC is equivalent to 9 MACs using BF16: + a * b + c = a1*b1 + a1*b2 + a1*b3 + a2*b1 + a2*b2 + a2*b3 + a3*b1 + a3*b2 + a3*b3 + c + + Example 27 demonstrates a similar technique for emulated SGEMM using TF32 with the Ampere architecture. + + Usage: + + $ ./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm --m=8192 --n=8192 --k=8192 +*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + +// Kernel Perf config +using ClusterTileShape = Shape<_256,_128,_16>; // Cluster-level tile shape +using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using CtaTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // Threadblock-level tile shape +using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape + +// Build the epilogue +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + CtaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + +// Build the mainloop +// Note: Emulated BF16x9 kernels need to manually specify a mainloop schedule and cannot use KernelScheduleAuto +using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedFastFP32SmemSm100; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + +// Compose into a kernel +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + + Options(): + help(false), + m(8192), n(8192), k(8192), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "78_blackwell_emulated_bf16x9_gemm\n\n" + << " Blackwell emulated BF16x9 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "78_blackwell_emulated_bf16x9_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 100a. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt new file mode 100644 index 00000000..1b36a4fd --- /dev/null +++ b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +cutlass_example_add_executable( + 78_blackwell_emulated_bf16x9_gemm + 78_blackwell_emulated_bf16x9_gemm.cu + ) +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 21166302..079adff4 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -146,14 +146,16 @@ foreach(EXAMPLE 64_ada_fp8_gemm_grouped 65_distributed_gemm 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling + 69_hopper_mixed_dtype_grouped_gemm 70_blackwell_gemm 71_blackwell_gemm_with_collective_builder 72_blackwell_narrow_precision_gemm - 73_blackwell_gemm_preferred_cluster + 73_blackwell_gemm_preferred_cluster 74_blackwell_gemm_streamk 75_blackwell_grouped_gemm 76_blackwell_conv 77_blackwell_fmha + 78_blackwell_emulated_bf16x9_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/README.md b/examples/README.md index dddfa4c3..ec39bf22 100644 --- a/examples/README.md +++ b/examples/README.md @@ -246,8 +246,6 @@ Hopper GEMM kernel with Top-K and softmax epilogue fusion. -[//]: # - * [70_blackwell_gemm](70_blackwell_gemm) Simple dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. @@ -280,8 +278,6 @@ Blackwell SM100 FMHA kernel -[//]: # - # CuTe - Programming Examples Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/). diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index ecc9d33d..b1490b02 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -171,20 +171,6 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // Create register tensors for the MMA to operate on Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) - -#if 0 - if (thread0()) { - print(" sA: "); print( sA); print("\n"); - print(" sB: "); print( sB); print("\n"); - print(thr_mma); - print("tCsA: "); print(tCsA); print("\n"); - print("tCsB: "); print(tCsB); print("\n"); - print("tCrA: "); print(tCrA); print("\n"); - print("tCrB: "); print(tCrB); print("\n"); - print("tCrC: "); print(tCrC); print("\n"); - } -#endif - // // PREDICATION // @@ -200,7 +186,6 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, // Allocate the preds for MMA- and MMA_MN-modes Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); - // Populate the predicates on M and N CUTE_UNROLL for (int i = 0; i < size(tCpA); ++i) { @@ -210,18 +195,6 @@ cooperative_gemm_predication(ThrMMA const& thr_mma, for (int i = 0; i < size(tCpB); ++i) { tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); } - -#if 0 - if (thread0()) { - print(" cA: "); print( cA); print("\n"); - print(" cB: "); print( cB); print("\n"); - print("tCcA: "); print(tCcA); print("\n"); - print("tCcB: "); print(tCcB); print("\n"); - print_tensor(tCpA); - print_tensor(tCpB); - } -#endif - // // PREFETCH k_block = 0 // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block @@ -330,24 +303,6 @@ cooperative_gemm_no_predication(uint32_t thread_idx, Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K - -#if 0 - if (thread0()) { - print(" sA: "); print(sA); print("\n"); - print(" sB: "); print(sB); print("\n"); - print(thr_mma); print("\n"); - print("tCrA: "); print(tCrA); print("\n"); - print("tCrB: "); print(tCrB); print("\n"); - print("tCrC: "); print(tCrC); print("\n"); - print(smem_thr_copy_A); print("\n"); - print("tCsA: "); print(tCsA); print("\n"); - print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n"); - print(smem_thr_copy_B); print("\n"); - print("tCsB: "); print(tCsB); print("\n"); - print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n"); - } -#endif - // // PREFETCH // @@ -434,14 +389,6 @@ cooperative_gemm(uint32_t thread_idx, // Clear accumulators clear(tCrC); - -#if 0 - if (thread0()) { - print(" sC: "); print(sC); print("\n"); - print(" tCsC: "); print(tCsC); print("\n"); - } -#endif - if constexpr (is_constant::value) { detail::cooperative_gemm_no_predication( thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 740565b5..9700b3f2 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -248,15 +248,6 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, // Recast Tensor src_v = recast(src); Tensor dst_v = recast(dst); - -#if 0 - if (thread0()) { - print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits); - print(" "); print(src); print(" => "); print(src_v); print("\n"); - print(" "); print(dst); print(" => "); print(dst_v); print("\n"); - } -#endif - return copy_if(TrivialPredTensor{}, src_v, dst_v); } else { return copy_if(TrivialPredTensor{}, src, dst); @@ -374,15 +365,6 @@ copy(Copy_Traits const& atom, // Copy_Traits m // Construct a new concrete Atom of the vector size using BulkAtom = Copy_Atom, CT_Args...>, SrcType>; auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; }); - -#if 0 - if (thread0()) { - print("copy blkcp -- found a max_common_layout of "); print(tiler); print("\n"); - print(" "); print(src); print("\n"); - print(" "); print(dst); print("\n"); - } -#endif - return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); } diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index a81b4e33..e1950b92 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -61,6 +61,7 @@ # define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 10a7d839..a157008c 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -219,7 +219,7 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else - #if defined(CUDA_VERSION) && CUDA_VERSION > 12060 + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;} else @@ -231,6 +231,7 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else #endif + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } } @@ -247,23 +248,17 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { case SmemSwizzleBits::B64: assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); return CU_TENSOR_MAP_SWIZZLE_64B; - #if (0) - case SmemSwizzleBits::B128: - assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits."); - return CU_TENSOR_MAP_SWIZZLE_128B; - - #else case SmemSwizzleBits::B128: switch (b) { default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; - #if defined(CUDA_VERSION) && CUDA_VERSION > 12060 + + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B; #endif + } - #endif - } } diff --git a/include/cute/arch/mma_sm100_desc.hpp b/include/cute/arch/mma_sm100_desc.hpp index 57934f24..3d748061 100644 --- a/include/cute/arch/mma_sm100_desc.hpp +++ b/include/cute/arch/mma_sm100_desc.hpp @@ -391,7 +391,6 @@ CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() { if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else - if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else @@ -399,7 +398,6 @@ CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() { if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else if constexpr (is_same_v) { return MXF4Format::E2M1; } else - { static_assert(sizeof(T) == 0, "Unknown type for UMMAFormat"); } } diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp index 26ef131c..d954544f 100644 --- a/include/cute/arch/mma_sm100_umma.hpp +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -49,7 +49,10 @@ template > 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -99,7 +91,10 @@ template > 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -150,7 +134,10 @@ template > 32); - print(desc_i); - print("UMMA TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif uint32_t mask[4] = {0, 0, 0, 0}; if (cute::elect_one_sync()) { asm volatile( @@ -201,7 +178,10 @@ template > 32); - print(desc_i); - print("UMMA TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif uint32_t mask[4] = {0, 0, 0, 0}; if (cute::elect_one_sync()) { asm volatile( @@ -245,13 +215,101 @@ struct SM100_MMA_F16BF16_TS } }; +template +struct SM100_MMA_F16BF16_SS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_TS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_TS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + template struct SM100_MMA_TF32_2x1SM_SS { - static_assert(M == 128 || M == 256, "SM100_MMA_TF32 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -266,15 +324,6 @@ struct SM100_MMA_TF32_2x1SM_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) - -#if 0 - if (thread0()) { - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -289,7 +338,7 @@ struct SM100_MMA_TF32_2x1SM_SS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS without SM100_MMA_TF32_2x1SM_SS"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); #endif } }; @@ -299,8 +348,8 @@ template struct SM100_MMA_F16BF16_2x1SM_SS { - static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -315,15 +364,6 @@ struct SM100_MMA_F16BF16_2x1SM_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) - -#if 0 - if (thread0()) { - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -338,7 +378,7 @@ struct SM100_MMA_F16BF16_2x1SM_SS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); #endif } }; @@ -349,9 +389,9 @@ template struct SM100_MMA_TF32_2x1SM_TS { - static_assert(M == 128 || M == 256, "SM100_MMA_TF32 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32 A from TMEM can't be transposed"); + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32_2x1SM_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -366,14 +406,6 @@ struct SM100_MMA_TF32_2x1SM_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) -#if 0 - if (thread0()) { - print(desc_i); - print("Umma TMEM-A addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("Umma TMEM-C addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -399,9 +431,9 @@ template struct SM100_MMA_F16BF16_2x1SM_TS { - static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16 A from TMEM can't be transposed"); + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -416,14 +448,6 @@ struct SM100_MMA_F16BF16_2x1SM_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) -#if 0 - if (thread0()) { - print(desc_i); - print("Umma TMEM-A addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("Umma TMEM-C addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -438,7 +462,91 @@ struct SM100_MMA_F16BF16_2x1SM_TS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); #endif } }; @@ -448,9 +556,9 @@ template struct SM100_MMA_S8_SS { - static_assert(is_same_v, "SM100_MMA_S8 result type can only be int32_t."); - static_assert(M == 64 || M == 128, "SM100_MMA_S8 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(is_same_v, "SM100_MMA_S8_SS result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS N-mode size should be 8 or a multiple of 16 between 16 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -465,16 +573,6 @@ struct SM100_MMA_S8_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -488,7 +586,7 @@ struct SM100_MMA_S8_SS "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); #endif } }; @@ -499,9 +597,9 @@ template struct SM100_MMA_S8_TS { - static_assert(M == 64 || M == 128, "SM100_MMA_S8 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 8 between 8 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8 A from TMEM can't be transposed"); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_TS N-mode size should be 8 or a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -516,16 +614,6 @@ struct SM100_MMA_S8_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print("UMMA TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -539,7 +627,7 @@ struct SM100_MMA_S8_TS "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); #endif } }; @@ -549,8 +637,8 @@ template struct SM100_MMA_S8_2x1SM_SS { - static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -565,16 +653,6 @@ struct SM100_MMA_S8_2x1SM_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -589,7 +667,7 @@ struct SM100_MMA_S8_2x1SM_SS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); #endif } }; @@ -600,9 +678,9 @@ template struct SM100_MMA_S8_2x1SM_TS { - static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8 A from TMEM can't be transposed"); + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_2x1SM_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -617,16 +695,6 @@ struct SM100_MMA_S8_2x1SM_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print("Umma TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -641,13 +709,15 @@ struct SM100_MMA_S8_2x1SM_TS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); #endif } }; struct SM100_MMA_F8F6F4_SS { + + using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; @@ -661,16 +731,6 @@ struct SM100_MMA_F8F6F4_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -684,7 +744,7 @@ struct SM100_MMA_F8F6F4_SS "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; @@ -694,8 +754,8 @@ template struct SM100_MMA_MXF8F6F4_SS { - static_assert(M == 64 || M == 128, "SM100_MMA_MXF8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -714,19 +774,6 @@ struct SM100_MMA_MXF8F6F4_SS uint32_t const& tsfb_addr) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptorBlockScaled desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - print("Umma SFA TMEM addr: 0x%08x\n", tsfa_addr); - print("Umma SFB TMEM addr: 0x%08x\n", tsfb_addr); - print("===================================\n"); - } -#endif if (cute::elect_one_sync()) { asm volatile( "{\n\t" @@ -738,7 +785,7 @@ struct SM100_MMA_MXF8F6F4_SS : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr)); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; @@ -749,9 +796,12 @@ template struct SM100_MMA_F8F6F4_TS { - static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4 A from TMEM can't be transposed"); + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_TS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -766,16 +816,6 @@ struct SM100_MMA_F8F6F4_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print("UMMA TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[4] = {0, 0, 0, 0}; asm volatile( @@ -789,7 +829,7 @@ struct SM100_MMA_F8F6F4_TS "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; @@ -800,9 +840,9 @@ template struct SM100_MMA_F8F6F4_2x1SM_TS { - static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4 A from TMEM can't be transposed"); + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_2x1SM_TS A from TMEM can't be transposed"); using DRegisters = void; using ARegisters = uint32_t[1]; @@ -817,15 +857,6 @@ struct SM100_MMA_F8F6F4_2x1SM_TS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print("UMMA TMEM addr: 0x%08x\n", tmem_a); - print(reinterpret_cast(desc_b)); - print("UMMA TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -840,13 +871,13 @@ struct SM100_MMA_F8F6F4_2x1SM_TS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; struct SM100_MMA_F8F6F4_2x1SM_SS -{ +{ using DRegisters = void; using ARegisters = uint64_t[1]; using BRegisters = uint64_t[1]; @@ -860,16 +891,6 @@ struct SM100_MMA_F8F6F4_2x1SM_SS uint64_t const& idescE) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptor desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - } -#endif if (cute::elect_one_sync()) { uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; asm volatile( @@ -884,7 +905,7 @@ struct SM100_MMA_F8F6F4_2x1SM_SS "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; @@ -894,8 +915,8 @@ template struct SM100_MMA_MXF8F6F4_2x1SM_SS { - static_assert(M == 128 || M == 256, "SM100_MMA_MXF8F6F4 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4 N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -912,19 +933,6 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS uint32_t const& tsfb_addr) { #if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) -#if 0 - if (thread0()) { - UMMA::InstrDescriptorBlockScaled desc_i; - desc_i.desc_ = uint32_t(idescE >> 32); - print(desc_i); - print(reinterpret_cast(desc_a)); - print(reinterpret_cast(desc_b)); - print("Umma TMEM addr: 0x%08x\n", tmem_c); - print("Umma SFA TMEM addr: 0x%08x\n", tsfa_addr); - print("Umma SFB TMEM addr: 0x%08x\n", tsfb_addr); - print("===================================\n"); - } -#endif if (cute::elect_one_sync()) { asm volatile( "{\n\t" @@ -937,7 +945,7 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS "r"(tsfa_addr), "r"(tsfb_addr)); } #else - CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); #endif } }; @@ -948,9 +956,9 @@ template struct SM100_MMA_MXF4_SS { - static_assert(M == 128, "SM100_MMA_MXF4 M-mode size should be 128 for 1 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert((VS == 16) || (VS == 32), "Vector size can only be 16 or 32."); + static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster OMMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32."); using DRegisters = void; using ARegisters = uint64_t[1]; @@ -1013,9 +1021,9 @@ template struct SM100_MMA_MXF4_2x1SM_SS { - static_assert(M == 128 || M == 256, "SM100_MMA_MXF4 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); - static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4 N-mode size should be a multiple of 16 between 16 and 256."); - static_assert((VS == 16) || (VS == 32), "Vector size can only be 16 or 32."); + static_assert(M == 128 || M == 256, "SM100_MMA_MXF4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_2x1SM_SS Vector size can only be 16 or 32."); using DRegisters = void; using ARegisters = uint64_t[1]; diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 96ffbec2..2d4eac73 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -521,15 +521,6 @@ make_cotiled_copy(Copy_Atom const& copy_atom, // Check validity CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); - -#if 0 - if (thread0()) { - print("data_layout : "); print(data_layout); print("\n"); - print("atom_tv_layout : "); print(atom_tv_layout); print("\n"); - print("layout_tv_data : "); print(layout_tv_data); print("\n"); - } -#endif - // // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them // @@ -552,15 +543,6 @@ make_cotiled_copy(Copy_Atom const& copy_atom, // (tid,vid) -> tile_coord auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); - -#if 0 - if (thread0()) { - print("tiler : "); print(tiler); print("\n"); - print("tile2data : "); print(tile2data); print("\n"); - print("layout_tv : "); print(layout_tv); print("\n"); - } -#endif - return make_tiled_copy_impl(copy_atom, layout_tv, tiler); } diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index bc0d956b..cd344fd5 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -394,15 +394,6 @@ make_tmem_warp_partitioner(Tensor const& tmem) // wid -> tmem_coord auto layout_t_tmem = composition(inv_tmem_layout, atom_t_layout); - -#if 0 - if (thread0()) { - print("input : "); print(tmem.data()); print(" o "); print(tmem_layout); print("\n"); - print("atom_t_layout : "); print(atom_t_layout); print("\n"); - print("layout_tv_tmem : "); print(layout_tv_tmem); print("\n"); - } -#endif - // // Tiler -- Find the active elements in the TMEM tensor and generate a tiler to extract them // @@ -425,15 +416,6 @@ make_tmem_warp_partitioner(Tensor const& tmem) // wid -> tile_coord auto layout_tv = composition(left_inverse(tile2tmem), layout_t_tmem); - -#if 0 - if (thread0()) { - print("tiler : "); print(tiler); print("\n"); - print("tile2tmem : "); print(tile2tmem); print("\n"); - print("layout_tv : "); print(layout_tv); print("\n"); - } -#endif - return make_tiler_impl(layout_tv, tiler); } diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index b4fdec0d..9c30ca53 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -1374,19 +1374,6 @@ tma_partition(Copy_Atom const& copy_atom, // Transform tile mode and coalesce Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) - -#if 0 - if (thread0()) { - print("cta_coord : "); print(cta_coord); print("\n"); - print("cta_layout : "); print(cta_layout); print("\n"); - print("gtensor : "); print(gtensor); print("\n"); - print("stensor : "); print(stensor); print("\n"); - print("layout_V : "); print(layout_V); print("\n"); - print("gtensor_v : "); print(gtensor_v); print("\n"); - print("stensor_v : "); print(stensor_v); print("\n"); - } -#endif - // Offset inside the TMA-mode for the multicast auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 35a5c8a1..fe2f3e0a 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -157,7 +157,6 @@ struct MMA_Atom> || (sizeof_bits_v::value_type> == 8 && (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) - , "Expecting ValTypeA type"); return make_tensor(static_cast(atensor)); } else { diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index 71a9dd2a..ff7d5c55 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -59,7 +59,6 @@ namespace UMMA { // Common layouts for UMMA Shared Memory // ////////////////////////////////////////////////// -// TODO: Extend for remaining sm100 new layouts using cute::GMMA::Layout_MN_INTER_Atom; using cute::GMMA::Layout_MN_SW32_Atom; using cute::GMMA::Layout_MN_SW64_Atom; @@ -275,19 +274,6 @@ make_umma_desc(Tensor const& tensor) } else { static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!"); } - -#if 0 - // DEBUG and SANITY - assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation - assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later - if (thread0()) { - print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); - print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); - //print(" desc canonical layout: "); print(canonical_layout); print("\n"); - print(desc); - } -#endif - return desc; } @@ -514,7 +500,7 @@ struct tmem_frg : tmem_frg_base "UMMA_2SM only accepts Interleaved or Duplicated"); static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128."); - if constexpr (M_MMA == 32) // TODO: Implement Duplicated mode for M_MMA = 32 + if constexpr (M_MMA == 32) { static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32"); // The "1x4" layout atom: (M,N) -> tmem_addr @@ -1013,7 +999,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS supports 32bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1077,7 +1063,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS supports 16bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1142,7 +1128,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_TS supports 32bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; using FrgTypeB = UMMA::smem_desc; @@ -1208,7 +1194,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; using FrgTypeB = UMMA::smem_desc; @@ -1261,6 +1247,155 @@ struct MMA_Traits +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + template == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS supports 32bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1338,7 +1473,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS supports 16bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1404,7 +1539,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32 supports 32bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_TS supports 32bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -1470,7 +1605,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16 supports 16bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -1523,6 +1658,152 @@ struct MMA_Traits +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + template @@ -1534,7 +1815,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS supports 8bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1599,7 +1880,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_TS supports 8bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; using FrgTypeB = UMMA::smem_desc; @@ -1663,7 +1944,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS supports 8bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -1728,7 +2009,7 @@ struct MMA_Traits == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8 supports 8bit types"); + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_TS supports 8bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -1795,16 +2076,18 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); - + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types"); + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_1sm; static_assert(sizeof_bits_v <= sizeof_bits_v && sizeof_bits_v <= sizeof_bits_v); - static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_F8F6F4 N-mode size should be a multiple of 8 between 8 and 256."); // Logical shape-K is always 256bits, transform to units of elements constexpr static int K = 32; @@ -1863,7 +2146,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4 supports types with leq 8bit types"); + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS supports types with leq 8bit types"); // Logical shape-K is always 256bits, transform to units of elements constexpr static int K = 32; @@ -1953,7 +2236,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_TS supports types with leq 8bit types"); using FrgTypeA = UMMA::tmem_frg_1sm; using FrgTypeB = UMMA::smem_desc; @@ -2023,8 +2306,10 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); - + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); + using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_2sm; @@ -2034,9 +2319,6 @@ struct MMA_Traits,Int,Int>; using ThrID = Layout<_2>; using ALayout = Layout,Int>>, @@ -2090,7 +2372,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_TS supports types with leq 8bit types"); using FrgTypeA = UMMA::tmem_frg_2sm; using FrgTypeB = UMMA::smem_desc; @@ -2159,7 +2441,7 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4 supports types with leq 8bit types"); + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS supports types with leq 8bit types"); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -2252,7 +2534,7 @@ struct MMA_Traits == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4 supports 4bit types"); + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_SS supports 4bit types"); // Logical shape-K is always 256bits, transform to units of elements constexpr static int K = 64; @@ -2345,7 +2627,7 @@ struct MMA_Traits == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4 supports 4bit types"); + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_2x1SM_SS supports 4bit types"); // Logical shape-K is always 256bits, transform to units of elements constexpr static int K = 64; diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index b9e78751..e3438f36 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -295,19 +295,6 @@ make_gmma_desc(Tensor const& tensor) } else { static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); } - -#if 0 - // DEBUG and SANITY - assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation - assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later - if (thread0()) { - print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); - print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); - //print(" desc canonical layout: "); print(canonical_layout); print("\n"); - print(desc); - } -#endif - return desc; } diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index d1b18694..c1a275c9 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1685,7 +1685,7 @@ blocked_product(Layout const& block, auto result = logical_product(append(block), append(tiler)); - return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat(Int<1>{})); + return zip(get<0>(result), get<1>(result)); } // raked_product -- Reproduce a block over a tiler with block-interleaving. @@ -1703,7 +1703,7 @@ raked_product(Layout const& block, auto result = logical_product(append(block), append(tiler)); - return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat(Int<1>{})); + return zip(get<1>(result), get<0>(result)); } // tile_to_shape -- Perform a product of a layout so that the result matches a target shape. @@ -1742,7 +1742,7 @@ tile_to_shape(Layout const& block, auto product_shape = ceil_div(target_shape, block_shape); - return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape); + return blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)); } // diff --git a/include/cutlass/arch/mma_sm89.h b/include/cutlass/arch/mma_sm89.h index 80a62b13..a4a8b1cb 100644 --- a/include/cutlass/arch/mma_sm89.h +++ b/include/cutlass/arch/mma_sm89.h @@ -45,12 +45,21 @@ //////////////////////////////////////////////////////////////////////////////// #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) - -# define CUTLASS_ARCH_MMA_SM89_SUPPORTED 1 +# define CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED #endif -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) -# define CUTLASS_ARCH_MMA_SM89_ENABLED +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) +# define CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) +# define CUTLASS_ARCH_MMA_F32_SM89_ENABLED +# endif + +# if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) +# define CUTLASS_ARCH_MMA_F16_SM89_ENABLED +# endif #endif //////////////////////////////////////////////////////////////////////////////// @@ -132,7 +141,7 @@ struct Mma< void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { -#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -198,7 +207,7 @@ struct Mma< void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { -#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -264,7 +273,7 @@ struct Mma< void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { -#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -330,7 +339,7 @@ struct Mma< void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c) const { -#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -359,5 +368,275 @@ struct Mma< } }; +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP16 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F16 = fe4m3 * fe4m3 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe4m3 * fe5m2 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe5m2 * fe4m3 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F16 = fe5m2 * fe5m2 + F16 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + cutlass::half_t, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = cutlass::half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + } // namespace arch } // namespace cutlass diff --git a/include/cutlass/arch/mma_sparse_sm89.h b/include/cutlass/arch/mma_sparse_sm89.h index b6c1bfe3..27c40dc4 100644 --- a/include/cutlass/arch/mma_sparse_sm89.h +++ b/include/cutlass/arch/mma_sparse_sm89.h @@ -44,12 +44,13 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// #if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) - -# define CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED 1 +# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED #endif -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) -# define CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED) +# define CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED +# endif #endif ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -113,7 +114,7 @@ struct SparseMma< int const id2 ) const { -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -198,7 +199,7 @@ struct SparseMma< int const id2 ) const { -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -283,7 +284,7 @@ struct SparseMma< int const id2 ) const { -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -368,7 +369,7 @@ struct SparseMma< int const id2 ) const { -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index f9f2be81..3a777113 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -51,10 +51,8 @@ # define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED #endif -#ifndef CUDA_ENABLE_PREFERRED_CLUSTER - #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) # define CUDA_ENABLE_PREFERRED_CLUSTER - #endif #endif namespace cutlass { diff --git a/include/cutlass/conv/collective/builders/sm90_common.inl b/include/cutlass/conv/collective/builders/sm90_common.inl index c0a48ebc..ddab1f7e 100644 --- a/include/cutlass/conv/collective/builders/sm90_common.inl +++ b/include/cutlass/conv/collective/builders/sm90_common.inl @@ -50,7 +50,7 @@ sm90_cluster_shape_to_im2col_tma_atom(UnimodalClusterShape unimodal_cluster_shap static_assert(cute::rank(unimodal_cluster_shape) == 1, "Use this function to figure out TMA for each mode individually."); - if constexpr (cute::size(unimodal_cluster_shape) == 1) { + if constexpr (UnimodalClusterShape::value == 1) { return cute::SM90_TMA_LOAD_IM2COL{}; } else { diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index cca462d5..6486e243 100644 --- a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -516,13 +516,13 @@ public: } if (is_im2col_A || is_im2col_B) { - // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1] - constexpr int32_t offset_limit = 1 << (16 / NumSpatialDimensions); + // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit] + constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B; for (int i = 0; i < problem_shape.RankS; ++i) { // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array - implementable = implementable && (flt_data[i+1] * problem_shape.dilation[i] >= 0) - && (flt_data[i+1] * problem_shape.dilation[i] <= (offset_limit - 1)); + implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) + && ((flt_data[i+1] - 1) * problem_shape.dilation[i] <= offset_limit); } if (!implementable) { diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index f29d5780..74b0e011 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -392,12 +392,12 @@ public: if (is_im2col_A || is_im2col_B) { // Check valid filter offsets for TMA_LOAD_IM2COL, unsigned int ranging from [0, offset_limit - 1] - constexpr int32_t offset_limit = 1 << (16 / NumSpatialDimensions); + constexpr int32_t offset_limit = (1 << (16 / NumSpatialDimensions)) - 1; auto flt_data = (ConvOp == conv::Operator::kWgrad) ? problem_shape.shape_C : problem_shape.shape_B; for (int i = 0; i < problem_shape.RankS; ++i) { // flt_data array contains [K, T, R, S, C], so pure filter [T, R, S] starts from the second position in the array - implementable = implementable && (flt_data[i+1] * problem_shape.dilation[i] >= 0) - && (flt_data[i+1] * problem_shape.dilation[i] < offset_limit); + implementable = implementable && ((flt_data[i+1] - 1) * problem_shape.dilation[i] >= 0) + && ((flt_data[i+1] - 1) * problem_shape.dilation[i] < offset_limit); } if (!implementable) { diff --git a/include/cutlass/conv/kernel/conv_universal_dispatch.hpp b/include/cutlass/conv/kernel/conv_universal_dispatch.hpp deleted file mode 100644 index 8507a171..00000000 --- a/include/cutlass/conv/kernel/conv_universal_dispatch.hpp +++ /dev/null @@ -1,182 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - -#pragma once - -#include "cutlass/conv/kernel/conv_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/workspace.h" - -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::conv::kernel { - -//////////////////////////////////////////////////////////////////////////////// - -enum class DispatchMode { - VoidC // Select between voidC and non-voidC kernel based on beta scaling -}; - -// Dispatch between two ConvUniversal kernels -template -class ConvUniversalDispatch; - -//////////////////////////////////////////////////////////////////////////////// - -template < - class ProblemShape_, - class MainloopWithC_, class EpilogueWithC_, - class MainloopVoidC_, class EpilogueVoidC_, - class TileScheduler_ -> -class ConvUniversalDispatch< - DispatchMode::VoidC, - ConvUniversal, - ConvUniversal, - cute::void_t -> : public ConvUniversal { -private: - using KernelWithC = ConvUniversal; - using KernelVoidC = ConvUniversal; - using FusionArguments = cute::remove_cvref_t; - -public: - // Mainloop derived types - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - - // Epilogue derived types - static_assert(not cute::is_void_v); - static_assert( cute::is_void_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - static_assert(cute::is_same_v); - - // TileID scheduler - static_assert(cute::is_same_v); - - static constexpr int SharedStorageSize = cute::max(KernelWithC::SharedStorageSize, KernelVoidC::SharedStorageSize); - - static_assert(KernelWithC::MaxThreadsPerBlock == KernelVoidC::MaxThreadsPerBlock); - - static_assert(KernelWithC::MinBlocksPerMultiprocessor == KernelVoidC::MinBlocksPerMultiprocessor); - - using Arguments = typename KernelWithC::Arguments; - - struct Params { - typename KernelWithC::Params withC; - typename KernelVoidC::Params voidC; - - void const* ptr_C; - decltype(FusionArguments{}.beta) beta; - decltype(FusionArguments{}.beta_ptr) beta_ptr; - decltype(FusionArguments{}.dBeta) dBeta; - cutlass::KernelHardwareInfo hw_info{}; - }; - - static size_t - get_workspace_size(Arguments const& args) { - return KernelWithC::get_workspace_size(args); - } - - static cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) { - return KernelWithC::initialize_workspace(args, workspace, stream, cuda_adapter); - } - - static Params - to_underlying_arguments(Arguments const& args, void* workspace) { - return { - KernelWithC::to_underlying_arguments(args, workspace), - KernelVoidC::to_underlying_arguments(reinterpret_cast(args), workspace), - args.epilogue.ptr_C, - args.epilogue.thread.beta, - args.epilogue.thread.beta_ptr, - args.epilogue.thread.dBeta, - args.hw_info - }; - } - - static dim3 - get_grid_shape(Params const& params) { - return KernelWithC::get_grid_shape(params.withC); - } - - CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { - using namespace cute; - - bool run_voidC = false; - if (params.ptr_C == nullptr) { - run_voidC = true; - } - else if (params.beta_ptr == nullptr) { // Host scalar beta - run_voidC = params.beta == 0; - } - else if (get<0>(params.dBeta) == 0 && get<1>(params.dBeta) == 0) { // Device scalar beta - auto L = get<3>(append<4>(params.withC.problem_shape, _1{})); - if (get<2>(params.dBeta) == repeat_like(L, 0) || size(L) == 1) { // Non-batched - run_voidC = *params.beta_ptr == 0; - } - } - - if (run_voidC) { - return kernel_voidC(params.voidC, smem_buf); - } - else { - return KernelWithC::operator()(params.withC, smem_buf); - } - } - -private: - KernelVoidC kernel_voidC; - -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::conv::kernel - -//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp index 840bae2e..4938767d 100644 --- a/include/cutlass/detail/collective.hpp +++ b/include/cutlass/detail/collective.hpp @@ -81,7 +81,6 @@ is_sm10x_f8f6f4_inputs() { cute::is_same_v || cute::is_same_v - || cute::is_same_v || cute::is_same_v || cute::is_same_v @@ -95,7 +94,6 @@ is_sm10x_f8f6f4_inputs() { cute::is_same_v || cute::is_same_v - || cute::is_same_v || cute::is_same_v || cute::is_same_v @@ -116,7 +114,6 @@ static constexpr bool is_sm10x_f8f6f4_element() { return (cute::is_same_v || cute::is_same_v - || cute::is_same_v || cute::is_same_v || cute::is_same_v @@ -129,7 +126,7 @@ is_sm10x_f8f6f4_element() { template CUTLASS_HOST_DEVICE static constexpr bool -is_sm10x_block_scale_mxf8f6f4_input() { +is_sm10x_mxf8f6f4_input() { // ElementType must be F8, F6, or F4 return ( cute::is_same_v || cute::is_same_v || @@ -144,7 +141,7 @@ is_sm10x_block_scale_mxf8f6f4_input() { template CUTLASS_HOST_DEVICE static constexpr bool -is_sm10x_block_scale_mxf4nvf4_input() { +is_sm10x_mxf4nvf4_input() { // ElementType must be F4 return ( cute::is_same_v || cute::is_same_v @@ -153,12 +150,12 @@ is_sm10x_block_scale_mxf4nvf4_input() { template struct sm10x_block_scale_runtime_input_t { - static constexpr bool IsMxF8F6F4MmaInput = is_sm10x_block_scale_mxf8f6f4_input(); - static constexpr bool IsMxF4NvF4MmaInput = is_sm10x_block_scale_mxf4nvf4_input(); + static constexpr bool IsF8F6F4MmaInput = is_sm10x_mxf8f6f4_input(); + static constexpr bool IsF4MmaInput = is_sm10x_mxf4nvf4_input(); - using Type = cute::conditional_t diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index ad5b5191..c9042351 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -301,7 +301,7 @@ struct LayoutAwareConvertImpl< } } }; - +/* // Specialization for E5M2 -> FP16 with [3120] value order template <> struct LayoutAwareConvertImpl< @@ -343,12 +343,12 @@ struct LayoutAwareConvertImpl< } } }; - +*/ // Specialization for INT8 -> BF16 with [3120] value order template <> struct LayoutAwareConvertImpl< cutlass::int8_t, - cutlass::half_t, + cutlass::bfloat16_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> > { @@ -363,9 +363,9 @@ struct LayoutAwareConvertImpl< >& dst) { static_assert(cute::is_same_v && - cute::is_same_v); + cute::is_same_v); using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; + using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; auto&& src_reg = cute::recast(src)(0); @@ -403,7 +403,7 @@ struct LayoutAwareConvertImpl< template <> struct LayoutAwareConvertImpl< cutlass::int8_t, - cutlass::bfloat16_t, + cutlass::half_t, cute::Layout, cute::Stride<_2,_1>>, cute::Layout<_4> > { @@ -418,9 +418,9 @@ struct LayoutAwareConvertImpl< >& dst) { static_assert(cute::is_same_v && - cute::is_same_v); + cute::is_same_v); using SrcArray = cutlass::Array; - using DstArray = cutlass::Array; + using DstArray = cutlass::Array; using RegArray = cutlass::AlignedArray; auto&& src_reg = cute::recast(src)(0); diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index 882a6e2f..c3a23387 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -506,7 +506,6 @@ sm100_get_smem_load_op() { template constexpr auto sm100_get_gmem_load_op() { - if constexpr (detail::is_im2col_mode) { return SM90_TMA_LOAD_IM2COL{}; } @@ -519,7 +518,6 @@ sm100_get_gmem_load_op() { template constexpr auto sm100_get_gmem_store_op() { - if constexpr (detail::is_im2col_mode) { return SM90_TMA_STORE_IM2COL{}; } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 5d9d6817..2759d0c6 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -208,7 +208,6 @@ struct IsThreadEpilogueOpWithElementwiseArguments< ThreadEpilogueOp, cute::void_t> : cute::true_type {}; - // Check if ActivationFn has 'Arguments' type defined template struct sm100_act_has_arguments : cute::false_type {}; @@ -499,7 +498,6 @@ public: using TensorMapStorage = typename EpilogueOp::SharedStorage; using PipelineStorage = typename LoadPipeline::SharedStorage; - // Planar complex kernels have two accumulator copies for the real and imaginary tensors. static constexpr int NumAccumulatorMtxs = Sm100EpilogueOpNumAccumulatorMtxs::value; template diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 0b007208..1354349d 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -986,6 +986,314 @@ public: return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state, acc_pipe_consumer_state); } + // API with Global Accumulator in registers for FastFP32 (emulated MMA) kernels. + // The accumulator in TMEM periodically loaded into the registers so that the MMA can clear out the TMEM accumulator + // values for better accuracy. This epilogue accepts the accumulator in registers and take TiledCopy for the + // TMEM->Reg as a parameter to be used in partitioning GMEM tensors C and D. + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class MmaTileMNK, + class TiledMma, + class AccEngine, + class AccLayout, + class TiledCopyT2R, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + MmaTileMNK mma_tile_mnk, + TiledMma tiled_mma, + cute::Tensor& tTR_rAcc, // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + TensorStorage& shared_tensors, + TensorMapD store_tensormap, + TiledCopyT2R tiled_t2r + ) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be Register resident."); + static_assert(rank(AccLayout{}) == 5, "Accumulators must be copy-partitioned: (T2R,T2R_M,T2R_N,EPI_M,EPI_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "CoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + int thread_idx = threadIdx.x % ThreadCount; + int warp_idx = thread_idx / NumThreadsPerWarp; + [[maybe_unused]] int lane_idx = thread_idx % NumThreadsPerWarp; + + auto coord_shape = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N),Int<1>{})); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(cta_tile_mnk)); + Tensor gD = local_tile(mD, take<0,2>(cta_tile_mnk), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // (t)hread-partition for (t)mem to (r)egister copy (tTR_) + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_sD = thread_t2r.partition_D(sD_epi(_,_,_0{})); // (T2R,T2R_M,T2R_N) + + // Allocate D and accumulator registers + Tensor tTR_rD = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); // (EPI_V) + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tTR_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tTR_rC = make_tensor(shape(tTR_sD)); // (T2R,T2R_M,T2R_N) + Tensor tSR_rC = thread_s2r.retile_D(tTR_rC); // (S2R,S2R_M,S2R_N) + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rD = thread_r2s.retile_S(tTR_rD); // (R2S,R2S_M,R2S_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(cta_tile_mnk), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tTR_cD_mn = thread_t2r.partition_D(flat_divide(cD_mn, EpilogueTile{})); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tTR_cD = make_counting_tensor(tTR_cD_mn.layout()); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tTR_cD = make_coord(M,N) - tTR_cD_mn(_0{}); // (m,n) + + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = false; // Register tensors reference T2R copy dst layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_t2r, + cD, + residue_cD, + tTR_cD, + residue_tTR_cD, + tTR_rC, + thread_idx + }; + + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = warp_idx == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + // If TMA store supported async transaction mbarriers we would not need this synchronous release behavior. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = store_pipe_producer_state.count() > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Begin the wait for the producer load results + ConsumerToken load_wait_token{BarrierStatus::WaitDone}; + if (is_producer_load_needed) { + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state); + } + + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // For each epilogue subtile within the CTA tile + CUTLASS_PRAGMA_UNROLL + for (int iter_n = 0; iter_n < size<3>(gD_epi); ++iter_n) { + CUTLASS_PRAGMA_UNROLL + for (int iter_m = 0; iter_m < size<2>(gD_epi); ++iter_m) { + int epi_m = iter_m, epi_n = iter_n; + bool is_first_iteration = iter_m == 0 && iter_n == 0; + bool is_last_iteration = iter_m == size<2>(gD_epi)-1 && iter_n == size<3>(gD_epi)-1; + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state, load_wait_token); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + // Ensure smem loads are complete before reusing smem for mixed types/layouts + if constexpr (ReuseSmemC && not (SmemLayoutC{} == SmemLayoutD{})) { + synchronize(); + } + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + // Let producer load warp know smem buffers are consumed and empty + if constexpr (not ReuseSmemC) { + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + bool issue_smem_store = true; + Tensor tTR_rAcc_epi_tile = tTR_rAcc(_,_,_,epi_m,epi_n); + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc_epi_tile)); // (EPI_V) + + // Vectorized fragment loop with visitor callback entry point + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rD_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + Tensor reduction_buffer = make_tensor(raw_pointer_cast(sD_epi(_,_,store_pipe_producer_state.index()).data()), + make_layout(stride<2>(get_nonswizzle_portion(SmemLayoutD{})), _1{})); + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rD_frg); + + // Copy output tile from register to smem + if (issue_smem_store) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Begin the wait for the next subtile producer load + load_wait_token = load_pipeline.consumer_try_wait(load_wait_state, is_last_iteration); + } + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + template CUTLASS_DEVICE void store_tail( diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 156d3588..fcd9fc56 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -82,11 +82,9 @@ struct FusionOperation { using ElementAmax = void; static constexpr bool IsAbsMaxSupported = false; - using ElementBlockScaleFactor = void; static constexpr int SFVecSize = 0; static constexpr bool IsBlockScaleSupported = false; // Umbrella variable to check BlockScaling support in the epilogues - using GmemLayoutTagScalefactor = void; }; @@ -484,7 +482,6 @@ struct LinCombDeEltActDePerRowBias static constexpr bool IsDePerRowBiasSupported = true; }; - template< int SFVecSize_, class ElementOutput_, diff --git a/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp index 24972141..8ec31ee6 100644 --- a/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp @@ -417,7 +417,6 @@ struct FusionCallbacks< using Impl::Impl; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// // D = alpha * acc + beta * C + per-row bias diff --git a/include/cutlass/exmy_base.h b/include/cutlass/exmy_base.h index 5c4e5460..3215d55a 100644 --- a/include/cutlass/exmy_base.h +++ b/include/cutlass/exmy_base.h @@ -747,7 +747,6 @@ private: src_sign_bit, dst_exponent, dst_mantissa); #endif - // TODO potential narrowing here if (dst_encoding.significand_hidden_bits(dst_mantissa) > 0b1) { // Significant became larger than 01.X...X. Divide significand by 2 and multiply exp by 2 @@ -848,16 +847,13 @@ CUTLASS_CONSTEXPR_IF_CXX17 auto fp_encoding_selector() { return cutlass::detail::FpBitRepresentation{}; } else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E5M2) { // FP8 - // TODO: Not tested. Will be done in another MR return cutlass::detail::FpBitRepresentation{}; } else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::E4M3) { // FP8 - // TODO: Not tested. Will be done in another MR return cutlass::detail::FpBitRepresentation{}; } else if CUTLASS_CONSTEXPR_IF_CXX17 (FpExMyCode == FpEncoding::UE4M3) { // FP8 - // TODO: Not tested. Will be done in another MR return cutlass::detail::FpBitRepresentation{}; } @@ -993,20 +989,16 @@ struct float_exmy_base return f; } - // TODO: Add rounding parameter with a reasonable default CUTLASS_HOST_DEVICE float_exmy_base convert_from_float(float const &flt) const { - // TODO: If we have a cvt instruction specialize in the children structs FP32BitRepresentation::Storage fp32_bits = FP32BitRepresentation::to_bits(flt); float_exmy_base float_exmy; float_exmy.storage = BitRepresentation::convert_from(fp32_bits, FP32BitRepresentation{}); return float_exmy; } - // TODO: Add rounding parameter with a reasonable default CUTLASS_HOST_DEVICE float convert_to_float(float_exmy_base const &x) const { - // TODO: If we have a cvt instruction specialize in the children structs FP32BitRepresentation::Storage fp32_bits; fp32_bits = BitRepresentation::convert_to(x.storage, FP32BitRepresentation{}); return detail::copy_bits(fp32_bits); diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 279c3aa6..a725a889 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -39,8 +39,13 @@ #include #endif #if !defined(__QNX__) +#include +#if defined(_MSC_VER) && defined(CCCL_VERSION) && CCCL_VERSION >= 2008000 +#include +#else #include #endif +#endif #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" diff --git a/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl new file mode 100644 index 00000000..0ddec554 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +// + +// + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template< + int CapacityBytes, + class CtaTileShape_MNK, + class TiledMma, + class KernelScheduleType, + UMMA::Major UmmaMajorA, + int ComplexComponent = 1, + int NumComputeMtxs = 3, + int carveout_bytes +> +constexpr cute::tuple +sm100_compute_stage_count_or_override_fast_fp32(StageCountAutoCarveout stage_count) { + constexpr int CtaM = get<0>(CtaTileShape_MNK{}); + constexpr int CtaN = get<1>(CtaTileShape_MNK{}); + static_assert(CtaN <= 128, "Can't support CtaN>128 tiles"); + constexpr int CtaK = get<2>(CtaTileShape_MNK{}); + using AtomThrID = typename TiledMma::AtomThrID; + // Detect 2x2 TMEM layout + constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN/2 : CtaN; + constexpr int TmemAWordsPerDP = ComplexComponent * NumComputeMtxs * CtaK / 2; + constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v; + constexpr bool IsAComputeinSmem = !IsAComputeinTmem; + constexpr int AccumulatorStageCount = (IsAComputeinTmem) ? (((TmemAccWordsPerDP * ComplexComponent == 128) ? 2 : 3) * ComplexComponent) : (512 / TmemAccWordsPerDP); + + constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32); + + constexpr int TmemInAStageCount_Potential = (IsAComputeinTmem) ? (512 - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000; + + constexpr auto load2transform_pipeline_bytes = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v * ComplexComponent; + constexpr auto b_bits = cute::sizeof_bits_v * ComplexComponent; + constexpr int ab_stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + + static_cast(load2transform_pipeline_bytes); + + constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + constexpr auto a_compute_bits = cute::sizeof_bits_v * ComplexComponent; + constexpr auto b_compute_bits = cute::sizeof_bits_v * ComplexComponent * ComplexComponent; + constexpr int ab_compute_stage_bytes = + cutlass::bits_to_bytes(NumComputeMtxs * a_compute_bits * int(IsAComputeinSmem) * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + // If ACompute is in TMEM, Acompute buffer has 0 bytes. + cutlass::bits_to_bytes(NumComputeMtxs * b_compute_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + + static_cast(transform2mma_pipeline_bytes); + + constexpr int ABComputeStageCount_Potential = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes); + // The number of SMEM buffers for A, B. ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount + constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential); + + constexpr int SmemCapacityAfterABComputeCarveout = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes); + // Can we boost the number of buffers for A and B? + constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes; + + static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2, "Not enough SMEM or TMEM capacity for selected tile size"); + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +} // namespace detail + + +// FastFP (9xBF16) MMA kernels builder +template < + class GmemLayoutATag, + int AlignmentA, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // The Cluster-level TileShape + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + float, // ElementA + GmemLayoutATag, // LayoutA + AlignmentA, + float, // ElementB + GmemLayoutBTag, // LayoutB + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int) + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (not cute::is_tuple::value && not cute::is_tuple::value) && + (cute::is_base_of_v) && + ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using ElementA = float; + using ElementB = float; + using ElementAMma = cutlass::bfloat16_t; + using ElementBMma = cutlass::bfloat16_t; + static constexpr int ScalingFactor = 8; + + using TiledMma = decltype(detail::sm100_make_trivial_fastFP32_tiled_mma()); + using AtomThrID = typename TiledMma::AtomThrID; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + // Take 3 compute buffers into account for swizzle selection + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + // Input transform kernel can not use TMA 2SM instructions. + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomA, SmemLayoutAtomACompute>; + + static constexpr int MMA_M = cute::size<0,0>(MmaShapeA_MK{}); + using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementA>, + cute::conditional_t<(UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v), + cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, SM100_TMEM_STORE_32dp32b8x>, // TS Implementation + Copy_Atom, ElementA>> // SS Implementation + >; + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Input transform kernel can not use TMA 2SM instructions. + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + // Take 3 compute buffers into account for swizzle selection + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomB, SmemLayoutAtomBCompute>; + using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementB>, + Copy_Atom, ElementBMma> + >; + + // SmemCarveout + static constexpr int NumBandsToCompute = 5; + static constexpr int AccPromotionInterval = 1; + static constexpr int SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( CLCPipelineStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage); + + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< + Sm100ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA>(StageCountType{}); + + static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); + static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); + static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info); + + using AccumulatorCopyAtom = cute::SM100_TMEM_LOAD_32dp32b32x; + + using DispatchPolicy = cute::conditional_t, + cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount, + Transform2MmaPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + NumBandsToCompute, + ScalingFactor, + AccPromotionInterval, + ClusterShape_MNK, + AccumulatorCopyAtom> + >; + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomPairA, + CopyAtomPairA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomPairB, + CopyAtomPairB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index b4927883..31e3e923 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -71,7 +71,7 @@ template < > constexpr int sm100_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout stage_count) { - // For Mxf8f6f4 sub-bytes, ElementA/B will be passed in as uint8_t + // For MXF8F6F4 MMA, ElementA/B will be passed in as uint8_t // Each stage include (CollectiveMma::SharedStorage) // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) @@ -386,7 +386,7 @@ select_instr() { } else if constexpr (( sizeof_bits_v == 4 && (sizeof_bits_v == 6 || sizeof_bits_v == 8)) || ((sizeof_bits_v == 6 || sizeof_bits_v == 8) && sizeof_bits_v == 4)) { - // Fp4 can be mixed with FP6, Fp8 with Mxf8f6f4 only + // Fp4 can be mixed with FP6, Fp8 with MMA.MXF8F6F4 only return detail::blockscaled::BlockScaledInstr::MXF4F6F8; } else if constexpr (sizeof_bits_v == 4 && sizeof_bits_v == 4) { @@ -400,7 +400,7 @@ select_instr() { static_assert( cute::is_same_v && (cute::is_same_v && cute::is_same_v || cute::is_same_v && cute::is_same_v), - "Only MXF4 support with non-TN and Mxf8f6f4"); + "Only MXF4 support with non-TN and MMA.MXF8F6F4."); return detail::blockscaled::BlockScaledInstr::MXF4F6F8; } } @@ -636,7 +636,7 @@ struct CollectiveBuilder< static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; - static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only Mxf8f6f4 supports non-K major inputs"); + static_assert(UseMxf8f6f4 || (cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B()), "Only MMA.MXF8F6F4 supports non-K major inputs"); // Data type used by MMA instruction using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 464ffe89..8e53866a 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -477,6 +477,94 @@ sm100_make_trivial_tiled_mma() { } } +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + int Scale, + class KernelScheduleType +> +constexpr auto +sm100_make_trivial_fastFP32_tiled_mma() { + // MMA_2SM requested + if constexpr (cute::is_base_of_v ) { + using AtomLayout_MNK = decltype(make_layout(shape_div(ClusterShape_MNK{}, Shape<_2,_1,_1>{}))); + constexpr int M = cute::size<0>(TileShape_MNK{}); + constexpr int N = cute::size<1>(TileShape_MNK{}); + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_TS_SCALED{}); + } + else { // If A needs to be transposed by MMA, fall back to SMEM from A MMA instructions + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_SS_SCALED{}); + } + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v ) { + // using AtomLayout_MNK = Layout; + constexpr int M = cute::size<0>(TileShape_MNK{}); + constexpr int N = cute::size<1>(TileShape_MNK{}); + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_TS_SCALED{}); + } + else { // If A needs to be transposed by MMA, fall back to SMEM from A MMA instructions + return make_tiled_mma(cute::SM100_MMA_F16BF16_SS_SCALED{}); + } + } + else if constexpr (cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v) { + // Static cluster + if constexpr (cute::is_static_v) { + // For MMA_2SM we need a cluster shape that is multiple of 2x1 + // and only M=128 and M=256 are supported, otherwise, fall back to MMA_1SM + if constexpr (cute::get<0>(ClusterShape_MNK{}) % 2 == 0 && + (cute::get<0>(TileShape_MNK{}) / cute::get<0>(ClusterShape_MNK{})) % 64 == 0) { + if constexpr (!cute::is_base_of_v) { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + else { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + } + else { + if constexpr (!cute::is_base_of_v) { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + else { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + } + } + // Dynamic cluster shape means we cannot assume we can use 2SM MMA + else { + if constexpr (!cute::is_base_of_v) { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + else { + return sm100_make_trivial_fastFP32_tiled_mma(); + } + } + } + else { + static_assert(cutlass::detail::dependent_false == 0, + "Unsupported policy for SM100 collective builder."); + } +} /** * @brief Check for U4_UNPACK_U8, U6_UNPACK_U8 alignment requirement @@ -547,22 +635,22 @@ template < 8 || cute::sizeof_bits_v < 8; + constexpr bool is_f8f6f4_subbytes = cute::sizeof_bits_v < 8 || cute::sizeof_bits_v < 8; - return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && - ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); + return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && + ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); } template constexpr bool sm1xx_blockscaled_gemm_is_aligned() { // Only support blocksscaled gemm alignment check - constexpr bool is_f6f4_subbytes = (cute::sizeof_bits_v < 8 || cute::sizeof_bits_v < 8) && + constexpr bool is_mxf8f6f4_subbytes = (cute::sizeof_bits_v < 8 || cute::sizeof_bits_v < 8) && (cute::is_base_of_v ); - return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && - ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); + return ((cute::sizeof_bits_v * AlignmentA) % cutlass::detail::get_input_alignment_bits() == 0) && + ((cute::sizeof_bits_v * AlignmentB) % cutlass::detail::get_input_alignment_bits() == 0); } } // namespace detail diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index c937619c..87b62de4 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -82,7 +82,7 @@ template< int carveout_bytes> constexpr int sm100_compute_stage_count_or_override(StageCountAutoCarveout stage_count) { - // For F8F6F4 sub-bytes, ElementA/B will be passed in as uint8_t + // For F8/F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t // For Planar Complex, ElementA/B will be passed in as cutlass::complex // Each stage include (CollectiveMma::SharedStorage) // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) @@ -253,7 +253,9 @@ struct CollectiveBuilder< static constexpr uint32_t TotalTmemRows = 128; static constexpr uint32_t Sm100TmemCapacityColumns = 512; static constexpr uint32_t TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; - static constexpr uint32_t AccumulatorPipelineStageCount = TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{})); + static constexpr uint32_t AccumulatorPipelineStageCount = (is_2sm || (!is_2sm && size(shape<0,0>(MmaShapeA_MK{}) > 64))) ? + TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{})) + : (Sm100TmemCapacityColumns / cute::size<1>(CtaTileShape_MNK{})) * 2; // 1SM MMA_M = 64 case static_assert(AccumulatorPipelineStageCount > 0, "Accumulator pipeline stage count must be positive. This error probably means that TileShape_MNK and/or TiledMma::ThrLayoutVMNK are wrong."); // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 8b2452af..4209fd87 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -261,8 +261,9 @@ struct CollectiveBuilder< using SmemLayoutAtomB = decltype(detail::ss_smem_selector< GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - static constexpr int Sm90ReducedSmemCapacityBytes = - detail::sm90_smem_capacity_bytes; + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); + static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); @@ -368,7 +369,12 @@ public: return t; } else { + if constexpr (cute::is_pointer_v) { + return &cute::stride(*t); + } + else { return cute::stride(t); + } } } @@ -441,14 +447,20 @@ public: static constexpr int Sm90ReducedSmemCapacityBytes = detail::sm90_smem_capacity_bytes - KernelSmemCarveout; static constexpr int PipelineStages = IsMixedInput ? + ( IsArrayOfPointersGemm ? + detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) : detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}) + ) : detail::compute_stage_count_or_override(StageCountType{}); using DispatchPolicy = cute::conditional_t - , MainloopSm90TmaGmmaRmemAWarpSpecialized>; + cute::conditional_t, + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>, + MainloopSm90TmaGmmaRmemAWarpSpecialized>; using SmemCopyAtomA = cute::conditional_t>; using SmemCopyAtomB = cute::conditional_t, void>; diff --git a/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl index 744402e5..d6702930 100644 --- a/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl +++ b/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl @@ -71,15 +71,15 @@ struct Sm90GemmSparseConfig { using ElementEMmaSparsity = Int; // MMA type - static constexpr bool IsQmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + static constexpr bool IsF8 = cute::is_same_v && ElementAMmaSparsity{} == _2{} || cute::is_same_v && ElementAMmaSparsity{} == _2{}; - static constexpr bool IsImma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + static constexpr bool IsI8 = cute::is_same_v && ElementAMmaSparsity{} == _2{} || cute::is_same_v && ElementAMmaSparsity{} == _2{}; - static constexpr bool IsHmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + static constexpr bool IsF16BF16 = cute::is_same_v && ElementAMmaSparsity{} == _2{} || cute::is_same_v && ElementAMmaSparsity{} == _2{}; - static constexpr bool IsTfmma = cute::is_same_v && ElementAMmaSparsity{} == _2{} || + static constexpr bool IsTF32 = cute::is_same_v && ElementAMmaSparsity{} == _2{} || cute::is_same_v && ElementAMmaSparsity{} == _2{}; - static_assert(int(IsQmma) + int(IsImma) + int(IsHmma) + int(IsTfmma) == 1, "Ambigious Input Type Config (failed to choose MMA type)"); + static_assert(int(IsF8) + int(IsI8) + int(IsF16BF16) + int(IsTF32) == 1, "Ambigious Input Type Config (failed to choose MMA type)"); // Number of ElementARaw stored in ElementAMmaRaw. For Hopper this is always 1. using ElemsARawPerElementAMmaRaw = _1; @@ -89,12 +89,12 @@ struct Sm90GemmSparseConfig { static_assert(ElementASparsity{} == _2{}, "ElementASparsity must be 2 for Hopper Sparse Gemm"); // Logical/Physical ElementA per Chunk - using LogicalElemsAPerChunk = conditional_t; + using LogicalElemsAPerChunk = conditional_t; using PhysicalElemsAPerChunk = Int; // Metadata Bits using ElementEBitsPerChunk = _4; - using ElementEBitsPerElementAMma = cute::conditional_t; + using ElementEBitsPerElementAMma = cute::conditional_t; // Metadata Layout. Unit in corresbonding logical elements. // Basic metadata block is (16,64) for 8-bit, (16,32) for 16-bit, (16,16) for 32-bit data types. @@ -114,8 +114,8 @@ struct Sm90GemmSparseConfig { using TensorEAtom_8bit = decltype(make_ordered_layout(Shape<_64,MinTileShapeK>{}, Step < _1, _0>{})); - using TensorEAtom = cute::conditional_t<(IsQmma || IsImma), TensorEAtom_8bit, - cute::conditional_t>; // Logical elems that construct the atomK for tensorE/A. diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index c54cf907..9623900b 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -40,8 +40,9 @@ #include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" #include "cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl" #if !defined(__CUDACC_RTC__) -#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" -#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" #endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index a57e5a08..f6a0dcb3 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -46,11 +46,16 @@ #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_sparse_mma_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" -#if !defined(__CUDACC_RTC__) -#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" -#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp" +#include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp" + #include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" #endif // !defined(__CUDACC_RTC__) diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index 0b0e2e3a..808e4495 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -682,11 +682,11 @@ struct CollectiveMma< auto mSFB_nkl = [=](){ if constexpr (IsCtaN192) { Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); - auto x = stride<0,2>(mSFB_tmp); - auto y = ceil_div(shape<0,2>(mSFB_tmp), 4); - auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), shape<0,1>(mSFB_tmp), + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); - auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), stride<0,1>(mSFB_tmp), + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); } diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index b28a3075..65718878 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -717,11 +717,11 @@ struct CollectiveMma< auto mSFB_nkl = [=](){ if constexpr (IsCtaN192) { Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB_)); - auto x = stride<0,2>(mSFB_tmp); - auto y = ceil_div(shape<0,2>(mSFB_tmp), 4); - auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), shape<0,1>(mSFB_tmp), + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); - auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), stride<0,1>(mSFB_tmp), + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); } diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp new file mode 100644 index 00000000..b623ca97 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp @@ -0,0 +1,1126 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + float, + StrideA_, + float, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = float; + using PackedElementA = float2; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementAMma = typename TiledMma::ValTypeA; + using PackedElementAMma = uint32_t; + using ElementB = float; + using PackedElementB = float2; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using ElementBMma = typename TiledMma::ValTypeB; + using PackedElementBMma = uint32_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(cute::is_same_v, "Input type A should be float"); + static_assert(cute::is_same_v, "Input type B should be float"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || NumBandsToCompute == 4 || NumBandsToCompute == 3), + "9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(append(CtaShapeB_NK{}, Int{}), Int{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> smem_ACompute; + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A{nullptr}; + StrideA dA{}; + ElementB const** ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + cute::TmaDescriptor* tensormaps; + ElementA const** ptr_A; + ElementB const** ptr_B; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Tensor shapes for Ptr-Array are initialized correctly here. + auto [M,N,K,mock_L] = problem_shape.get_host_problem_shape(0); + // Batches/Groups are managed by using appropriate pointers to input matrices + mock_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(M,K,mock_L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(N,K,mock_L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + reinterpret_cast(args.ptr_B) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto [M,N,K,L] = problem_shape.get_host_problem_shape(0); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, + input_tensormaps] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage, + int32_t const sm_count, int32_t const sm_idx) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init(params, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + auto tArA_x2 = recast>(tArA); + auto tArA_temp_x2 = recast>(tArA_temp); + auto tArACompute_x2 = recast>(tArACompute); + + auto tBrB_x2 = recast>(tBrB); + auto tBrB_temp_x2 = recast>(tBrB_temp); + auto tBrBCompute_x2 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x2, tBrBCompute_x2, cutlass::NumericArrayConverter::convert); + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x2, tBrB_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x2, tBrB_temp_x2, tBrB_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x2, tBrB_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x2, tArACompute_x2, cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x2, tArA_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x2, tArA_temp_x2, tArA_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x2, tArA_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. + Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0,0).layout()); + + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, --remaining_accum_promotions) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = tCrB(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. + + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. + // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{}, _), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block>{}); + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init(Params const& mainloop_params, int32_t const sm_count, int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to gmem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gA_tensormap = make_tensor(tma_desc_a, Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(tma_desc_b, Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(gA_tensormap)); + copy(recast(pB_tensormap), recast(gB_tensormap)); + } + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Bringing tensormaps to smem (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_fetch_to_smem( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) const { + Tensor gA_tensormap = make_tensor(make_gmem_ptr(get<0>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor gB_tensormap = make_tensor(make_gmem_ptr(get<1>(input_tensormaps)), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormap.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(gA_tensormap), recast(sA_tensormap)); + copy(recast(gB_tensormap), recast(sB_tensormap)); + + cp_async_fence(); + cp_async_wait<0>(); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormap.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormap, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + int32_t next_batch, + uint32_t lane_predicate) { + if (lane_predicate) { + // Bringing tensormaps to smem + tensormaps_fetch_to_smem(shared_tensormap, input_tensormaps); + + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormap, mainloop_params, next_batch); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormap, + cute::tuple const& input_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormap.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormap.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + +private: + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp new file mode 100644 index 00000000..705562f8 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp @@ -0,0 +1,1018 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +namespace detail { +template +struct CollectiveMmaEmulatedLayoutAtomType { + using InputLayoutAtom = InputLayoutAtom_; + using ComputeLayoutAtom = ComputeLayoutAtom_; +}; + +template +struct CollectiveMmaEmulatedCopyType { + using InputCopyAtom = InputCopyAtom_; + using ComputeCopyAtom = ComputeCopyAtom_; +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for FastF32 Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + int NumBandsToCompute_, + int ScalingFactor_, + int AccPromotionInterval_, + class AccumulatorCopyAtom_, + class ClusterShape, + class TileShape_, + class StrideA_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>, + TileShape_, + float, + StrideA_, + float, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ + // + // Type Aliases + // + + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedFastF32< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + NumBandsToCompute_, + ScalingFactor_, + AccPromotionInterval_, + ClusterShape, + AccumulatorCopyAtom_>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementA = float; + using PackedElementA = float2; + using StrideA = StrideA_; + using ElementAMma = typename TiledMma::ValTypeA; + using PackedElementAMma = uint32_t; + using ElementB = float; + using PackedElementB = float2; + using StrideB = StrideB_; + using ElementBMma = typename TiledMma::ValTypeB; + using PackedElementBMma = uint32_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(cute::is_same_v, "Input type A should be float"); + static_assert(cute::is_same_v, "Input type B should be float"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + static_assert(cute::is_same_v, "Compute type A should be cutlass::bfloat16_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; + + // Get the Algorithm parameters + constexpr static int NumComputeMtxs = 3; + constexpr static int NumBandsToCompute = DispatchPolicy::NumBandsToCompute; + constexpr static int ScalingFactor = DispatchPolicy::ScalingFactor; + constexpr static int AccPromotionInterval = DispatchPolicy::AccPromotionInterval; + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}) / DispatchPolicy::AccPromotionInterval; + constexpr static int NumBandsMax = 5; + static_assert(NumBandsToCompute <= NumBandsMax && NumBandsToCompute >= 3, "NumBandsToCompute should be less than maximum number of bands"); + + // Copy atom for Accumulator + using AccumulatorCopyAtom = typename DispatchPolicy::AccumulatorCopyAtom; + + static_assert((NumBandsToCompute == 5 || NumBandsToCompute == 4 || NumBandsToCompute == 3), + "9xBF16 with 5/4/3 Bands are supported"); + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(append(CtaShapeA_MK{}, Int{}), Int{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutBCompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomBCompute{}, + append(append(CtaShapeB_NK{}, Int{}), Int{}))); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyB - invalid TMA copy atom specified."); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + struct TensorStorageUntransformed { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + }; + + struct TensorStorageTransformedAinSmem { + alignas(1024) cute::ArrayEngine> smem_ACompute; + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + union TensorStorageTransformedAinTmem { + alignas(1024) cute::ArrayEngine smem_ACompute; // No smem_ACompute + alignas(1024) cute::ArrayEngine> smem_BCompute; + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * size<2>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * size<2>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + Load2TransformPipeline pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK mainloop_load2xform_pipeline_state for _writing_ + pipeline.producer_acquire(load2xform_pipeline_state, pipeline_flag); + int write_stage = load2xform_pipeline_state.index(); + + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop_pipe + ++load2xform_pipeline_state; + skip_wait = (k_tile_count <= 1); + pipeline_flag = pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + ++k_tile_iter; + } + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b); // multicast masks + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class GTensorB, class SrcTensorB, class DstTensorB + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementA and ElementB types should be the same."); + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAdA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM or TMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tBsB : (Copy,#Copy),MMA_Rest,MMA_N_Rest,MMA_K_Rest, NumComputeMtxs, SmemStages (In SMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAdACompute, + unused_tBgB, tBsB, tBsBCompute] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArA_temp = make_tensor(tAsA(_,_,_,_,0).shape()); + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + + auto tBrB = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrB_temp = make_tensor(tBsB(_,_,_,_,0).shape()); + auto tBrBCompute = make_tensor(tBsB(_,_,_,_,0).shape()); + + auto tArA_x2 = recast>(tArA); + auto tArA_temp_x2 = recast>(tArA_temp); + auto tArACompute_x2 = recast>(tArACompute); + + auto tBrB_x2 = recast>(tBrB); + auto tBrB_temp_x2 = recast>(tBrB_temp); + auto tBrBCompute_x2 = recast>(tBrBCompute); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Copy the input B matrix from SMEM + copy(AutoVectorizingCopy{}, tBsB(_,_,_,_,load2transform_consumer_index), tBrB); + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tBrB_x2, tBrBCompute_x2, cutlass::NumericArrayConverter::convert); + copy(AutoVectorizingCopy{}, tBrBCompute, tBsBCompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tBrBCompute_x2, tBrB_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tBrB_x2, tBrB_temp_x2, tBrB_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tBrB_x2, tBrB_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + CUTE_UNROLL + for (int comp_mtx_index = 0; comp_mtx_index < NumComputeMtxs; ++comp_mtx_index) { + // Convert from fp32 -> bf16 + cute::transform(tArA_x2, tArACompute_x2, cutlass::NumericArrayConverter::convert); + copy(dst_copy_A, tArACompute, tAdACompute(_,_,_,_,comp_mtx_index,transform2mma_producer_index)); + + // if it is not the last compute matrix, scale and substract + if (comp_mtx_index < NumComputeMtxs - 1) { + // Convert from bf16 -> fp32 to substract + cute::transform(tArACompute_x2, tArA_temp_x2, cutlass::NumericArrayConverter::convert); + cute::transform(tArA_x2, tArA_temp_x2, tArA_x2, cutlass::minus>{}); + if constexpr (DispatchPolicy::ScalingFactor != 0) { + cute::transform(tArA_x2, tArA_x2, cutlass::scale>{(1 << DispatchPolicy::ScalingFactor)}); + } + } + } + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sB_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor sB = as_position_independent_swizzle_tensor(sB_orig); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. + Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); // ((128,16),m,k,PIPE) + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto reg2tmem_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0,0)); + auto thr_reg2tmem_tiled_copy = reg2tmem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2tmem_tiled_copy.partition_S(tensor_input2x); + auto partitioned_tensor_compute = thr_reg2tmem_tiled_copy.partition_D(fragment_compute); + return cute::make_tuple(reg2tmem_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto reg2smem_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0,0).layout()); + + auto thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(tensor_input); + auto partitioned_tensor_compute = thr_reg2smem_tiled_copy.partition_D(tensor_compute_ind_sw); + + return cute::make_tuple(AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + auto [dst_copy_B, tBsB, tBsBCompute] = + setup_copy_ops(sB, InputCopyAtomB{}, sBCompute, [&](auto &arg) {return TiledMma::make_fragment_B(arg);}, ComputeCopyAtomB{}); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + gB_nkl, tBsB, tBsBCompute); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + + // tCrA : (MMA), MMA_M, MMA_K, NumComputeMtxs, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, NumComputeMtxs, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + using ZeroScaler = cute::integral_constant; + using Scaler = cute::integral_constant; + + int remaining_accum_promotions = k_tile_count * StagesPerTile; + uint32_t mma2accum_skip_wait = (remaining_accum_promotions <= 0); + auto mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block += DispatchPolicy::AccPromotionInterval, --remaining_accum_promotions) { + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state, mma2accum_flag); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + + ++mma2accum_pipeline_producer_state; + mma2accum_skip_wait = (remaining_accum_promotions <= 1); + mma2accum_flag = mma2accum_pipeline.producer_try_acquire(mma2accum_pipeline_producer_state, mma2accum_skip_wait); + + auto tCrA0 = tCrA(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrA1 = tCrA(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrA2 = tCrA(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + auto tCrB0 = tCrB(_,_,_,0,transform2mma_pipeline_consumer_state_index); + auto tCrB1 = tCrB(_,_,_,1,transform2mma_pipeline_consumer_state_index); + auto tCrB2 = tCrB(_,_,_,2,transform2mma_pipeline_consumer_state_index); + + // MMA instructions Emulation + auto accumulate = UMMA::ScaleOut::Zero; + // First set of GEMMs that we need to perform for each band are unrolled to set compile-time constant + // scaling parameter. Scaled GEMM operations are only needed for the first MMA operation of each band. + + // Band 5 + if constexpr (NumBandsToCompute == 5) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[2] + accumulate = UMMA::ScaleOut::One; + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[2] + } + } + // Band 4 + if constexpr (NumBandsToCompute >= 4) { + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA1(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[1]*B[2] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[2]*B[1] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[1]*B[2] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[2]*B[1] + } + } + // Band 3 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB2(_,_,k_block), tCtC); // A[2]*B[0] + accumulate = UMMA::ScaleOut::One; + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[2] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB2(_,_,k_block+s), tCtC); // A[2]*B[0] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[1]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA2(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[2] + } + // Band 2 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB1(_,_,k_block), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[1]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB1(_,_,k_block+s), tCtC); // A[0]*B[1] + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA1(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[1]*B[0] + } + // Band 1 + cute::gemm(tiled_mma.with(accumulate, Scaler{}), tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + CUTLASS_PRAGMA_UNROLL + for (int s = 1; s < DispatchPolicy::AccPromotionInterval; s++) { + cute::gemm(tiled_mma.with(accumulate, ZeroScaler{}), tCrA0(_,_,k_block+s), tCrB0(_,_,k_block+s), tCtC); // A[0]*B[0] + } + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + } + + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + return cute::make_tuple(curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sBCompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_BCompute.begin()), SmemLayoutBCompute{}); + Tensor tCrB = tiled_mma.make_fragment_B(sBCompute); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + // Obtain a single accumulator + Tensor tAcc = tensor<0>(accumulators(_,_,_,_0{})); + // Apply epilogue subtiling + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + // Create the TMEM copy for single EpilogueTile. + // Note that EpilogueTile = CtaTile for NoSmem epilogue + auto tiled_t2r = make_tmem_copy(tmem_cp_atom, tAcc_epi(_,_,_0{},_0{})); + auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + Tensor tTR_gC = thread_t2r.partition_D(tAcc_epi); + Tensor tTR_rAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rGlobAcc = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Apply epilogue subtiling to bulk accumulator + // We need to tile the whole bulk_tmem allocation with EpilogueTile. + // The accumulation should be aware of the AccumulatorPipelineStages + Tensor tBulkAcc_epi = flat_divide(accumulators(make_coord(_,_),_0{},_0{}, _), EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,PIPE) + Tensor tTR_tBulkAcc = thread_t2r.partition_S(tBulkAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N,PIPE) + return cute::make_tuple(tiled_t2r, thread_t2r, tTR_tBulkAcc, tTR_rAcc, tTR_rGlobAcc); + } + + template + CUTLASS_DEVICE auto + accum(cute::tuple accum_inputs, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_consumer_state, + int k_tile_count) { + auto [tiled_t2r, thread_t2r, tTR_tBulkAcc, + tTR_rAcc, tTR_rGlobAcc] = accum_inputs; + + + Tensor tTR_rAcc_float2 = recast>(tTR_rAcc); // (T2R/2,T2R_M,T2R_N) + Tensor tTR_rGlobAcc_float2 = recast>(tTR_rGlobAcc); // (T2R/2,T2R_M,T2R_N) + + // Clear the global accumulator + CUTE_UNROLL + for (int i = 0; i 0; --k_tile_count) { + // The stage is limited to a CTA tile + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block>{}); + + cutlass::arch::fence_view_async_tmem_load(); // Need a fence bw TMEM_LOAD and arrive + mma2accum_pipeline.consumer_release(mma2accum_pipeline_consumer_state); + + ++mma2accum_pipeline_consumer_state; + skip_wait = ((k_tile_count <= 1) && (k_block >= (StagesPerTile-1))); + mma2accum_flag = mma2accum_pipeline.consumer_try_wait(mma2accum_pipeline_consumer_state, skip_wait); + } + } + return cute::make_tuple(mma2accum_pipeline_consumer_state, tTR_rGlobAcc); + } + +private: + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index ad1e0525..22a7d4be 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -240,7 +240,6 @@ public: // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. // We must also handle updating the pipeline transaction bytes on the fly. - // NOTE: Deleting this assertion without required changes will cause the code to hang. static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); private: @@ -490,8 +489,6 @@ public: : args_setup(args.ptr_A, args.ptr_B); } else if constexpr (ModeHasScales) { - // NOTE: fix chunk wise scaling - //auto scale_k = (K + args.chunk_size - 1) / args.chunk_size; auto scale_k = 1; ElementScale const* ptr_S = reinterpret_cast(args.ptr_S); StrideScale dS{}; @@ -998,7 +995,6 @@ public: Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - // NOTE: Check this when applying swizzling PR on top of GGMD Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); @@ -1049,7 +1045,6 @@ public: Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 0, smem_pipe_read.index()); - // NOTE: Check this when applying swizzling PR on top of GGMD Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); Utils::dequantize_A_kblock(tCrA_load, tCrA_mma, partitioned_extra_info, 0); @@ -1248,7 +1243,6 @@ public: if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { NonVoidElementScale const* ptr_S = nullptr; - // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; auto scale_k = 1; Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_scale, tensor_scale, @@ -1256,7 +1250,6 @@ public: } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { ElementZero const* ptr_Z = nullptr; - // NOTE: figure out chunk wise scaling. auto scale_k = (K + mainloop_params.chunk_size - 1) / mainloop_params.chunk_size; auto scale_k = 1; Tensor tensor_zero = make_tensor(detail::get_logical_ptr(ptr_Z), make_shape(M,scale_k,Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_zero, tensor_zero, diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index f95b4f9e..e06ead97 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -531,7 +531,7 @@ struct CollectiveMma< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -557,7 +557,6 @@ struct CollectiveMma< PipelineState smem_pipe_release = smem_pipe_read; // Per block scale values for operand A and B - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 2cc64ba7..155d023d 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -351,6 +351,23 @@ struct MainloopSm90TmaGmmaWarpSpecializedSparseFP8 : MainloopSm90TmaGmmaWarpSpecializedSparse { }; +// Mixed precision version n-buffer in rmem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule for Ptr-Array and Grouped Gemm +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative +> +struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + static_assert( + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); +}; + template< int SchedulerPipelineStageCount_, @@ -373,6 +390,16 @@ struct KernelTmaWarpSpecializedBlockScaledSm100 final { +// InputTransform GEMM +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedInputTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Ptr-Array Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms template< int SchedulerPipelineStageCount_, @@ -393,6 +420,15 @@ struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +// Ptr-Array InputTransform GEMM +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedInputTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; ////////////////////////////////////////////////////////////////////////////// @@ -401,32 +437,67 @@ struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm100 final { // Collective Builder Tag Property // +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// +// SM100 Dispatch Policies +// +/////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Base Dispatch Policies struct KernelSchedule1Sm {}; struct KernelSchedule2Sm {}; struct KernelScheduleSm100 {}; -struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; - -struct KernelScheduleBlockScaledGemmSm100 : KernelScheduleSm100 {}; -struct KernelScheduleMxNvf4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; -struct KernelScheduleMxf8f6f4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; - -struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {}; -struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {}; -struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; -struct KernelSchedulePtrArrayMxf8f6f4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; - - -// -// Collective Builder Tag -// Only used in CollectiveBuilder -// +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 Dense GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy // Dense GEMM: Specialize for 1SM vs 2SM -struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; -struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; +struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder +struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder +// Dense GEMM + (Ptr Array or Group GEMM) +struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {}; +// Ptr-Array Dense GEMM: Specialize for 1SM vs 2SM +struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {}; +struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {}; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 Planar Complex GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleSm100PlanarComplexGemm : KernelScheduleSm100{}; +// Planar Complex GEMM: Specialize for 1SM vs 2SM +struct KernelTmaWarpSpecialized1SmPlanarComplexSm100 final : KernelSchedule1Sm, KernelScheduleSm100PlanarComplexGemm { }; +struct KernelTmaWarpSpecialized2SmPlanarComplexSm100 final : KernelSchedule2Sm, KernelScheduleSm100PlanarComplexGemm { }; +// Planar Complex GEMM + (Ptr Array or Group GEMM) +struct KernelScheduleSm100PtrArrayPlanarComplexGemm : KernelScheduleSm100PlanarComplexGemm {}; +struct KernelPtrArrayTmaWarpSpecialized1SmPlanarComplexSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayPlanarComplexGemm {}; +struct KernelPtrArrayTmaWarpSpecialized2SmPlanarComplexSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayPlanarComplexGemm {}; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 FastF32 (9xBF16) GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleSm100FastFP32Gemm : KernelScheduleSm100 {}; +struct KernelTmaWarpSpecializedFastFP32SmemSm100 : KernelScheduleSm100FastFP32Gemm { }; +// Dispatch policies without smem load the A operand from tmem +struct KernelTmaWarpSpecialized1SmFastFP32Sm100 final : KernelSchedule1Sm, KernelScheduleSm100FastFP32Gemm { }; +struct KernelTmaWarpSpecialized2SmFastFP32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100FastFP32Gemm { }; +// Dispatch policies with smem load the A operand from smem +struct KernelTmaWarpSpecialized1SmFastFP32SmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedFastFP32SmemSm100 { }; +struct KernelTmaWarpSpecialized2SmFastFP32SmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedFastFP32SmemSm100 { }; +// Ptr-Array Transform GEMM: Specialize for 1SM vs 2SM FastF32 GEMM +struct KernelScheduleSm100PtrArrayFastFP32Gemm : KernelScheduleSm100FastFP32Gemm {}; +struct KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 : KernelScheduleSm100PtrArrayFastFP32Gemm { }; +struct KernelPtrArrayTmaWarpSpecialized1SmFastFP32Sm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayFastFP32Gemm { }; +struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32Sm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayFastFP32Gemm { }; +struct KernelPtrArrayTmaWarpSpecialized1SmFastFP32SmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { }; +struct KernelPtrArrayTmaWarpSpecialized2SmFastFP32SmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedPtrArrayFastFP32SmemSm100 { }; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 BlockScaled Dense GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleBlockScaledGemmSm100 : KernelScheduleSm100 {}; +struct KernelScheduleMxNvf4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; +struct KernelScheduleMxf8f6f4Sm100 : KernelScheduleBlockScaledGemmSm100 {}; // Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM struct KernelTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 { }; struct KernelTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelScheduleBlockScaledGemmSm100 { }; @@ -436,13 +507,10 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSch struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; - - -// Ptr-Array Dense GEMM: Specialize for 1SM vs 2SM -struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {}; -struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {}; - - +// BlockScaled Dense GEMM + (Ptr Array or Group GEMM) +struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {}; +struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; +struct KernelSchedulePtrArrayMxf8f6f4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; // Ptr-Array Block Scaled Dense GEMM: Specialize for instruction type, scale factor vector size, and 1SM vs. 2SM struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { }; struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayBlockScaledGemmSm100 { }; @@ -454,6 +522,7 @@ struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1S struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { }; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -488,6 +557,55 @@ struct MainloopSm100TmaUmmaWarpSpecializedBlockScaled { +// n-buffer in smem, pipelined with Blackwell Fast FP32 kernel with UMMA (HwScaled) and TMA, +// Warp specialized dynamic schedule +template< + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Load2TransformPipelineStageCount_, + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Transform2MmaPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // Number of MMA Bands to be computed in a single FastF32 MMA operation. + // For BF16 emulation, we have 3 compute matrices, with 9 MMAs forming 5 bands. + // We can eliminate bands 4 and/or 5 (up to last 3 MMA operations). + // Valid values are 3, 4, 5 + int NumBandsToCompute_, + // Scaling factor for decomposed matrices (2^ScalingFactor) + // 8 for BF16, 11 for TF32 + int ScalingFactor_, + // Number of UMMA instructions emulated a single stage + // Ex: Staged16 has 1 FastF32 MMA per stage + // Should be smaller than K-mode of a single ClusterTile + int AccPromotionInterval_, + // ClusterShape for the kernel + class ClusterShape_ = Shape<_1,_1,_1>, + // The TMEM_LOAD atom to be used for loading local accumulator + // from TMEM to registers + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x +> +struct MainloopSm100TmaUmmaWarpSpecializedFastF32 { + constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_; + constexpr static int NumBandsToCompute = NumBandsToCompute_; + constexpr static int ScalingFactor = ScalingFactor_; + constexpr static int AccPromotionInterval = AccPromotionInterval_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32; + using ClusterShape = ClusterShape_; + using AccumulatorCopyAtom = AccumulatorCopyAtom_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedInputTransformSm100; + + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = Load2TransformPipelineStageCount; +}; + + + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -520,6 +638,55 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled { +// n-buffer in smem, pipelined with Blackwell Fast FP32 kernel with UMMA (HwScaled) and TMA, +// Warp specialized dynamic schedule +template< + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Load2TransformPipelineStageCount_, + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Transform2MmaPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // Number of MMA Bands to be computed in a single FastF32 MMA operation. + // For BF16 emulation, we have 3 compute matrices, with 9 MMAs forming 5 bands. + // We can eliminate bands 4 and/or 5 (up to last 3 MMA operations). + // Valid values are 3, 4, 5 + int NumBandsToCompute_, + // Scaling factor for decomposed matrices (2^ScalingFactor) + // 8 for BF16, 11 for TF32 + int ScalingFactor_, + // Number of UMMA instructions emulated a single stage + // Ex: Staged16 has 1 FastF32 MMA per stage + // Should be smaller than K-mode of a single ClusterTile + int AccPromotionInterval_, + // ClusterShape for the kernel + class ClusterShape_ = Shape<_1,_1,_1>, + // The TMEM_LOAD atom to be used for loading local accumulator + // from TMEM to registers + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x +> +struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { + constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_; + constexpr static int NumBandsToCompute = NumBandsToCompute_; + constexpr static int ScalingFactor = ScalingFactor_; + constexpr static int AccPromotionInterval = AccPromotionInterval_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32; + using ClusterShape = ClusterShape_; + using AccumulatorCopyAtom = AccumulatorCopyAtom_; + using ArchTag = arch::Sm100; + using Schedule = KernelPtrArrayTmaWarpSpecializedInputTransformSm100; + + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = Load2TransformPipelineStageCount; +}; + + + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 50245571..77b2e1ea 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -63,6 +63,10 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + // TileID scheduler + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveMainloop::NumAccumThreads; // 4 warps + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumTransformationThreads = CollectiveMainloop::NumTransformationThreads; // 4 warps + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + + NumEpilogueThreads + NumTransformationThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr cutlass::gemm::detail::KernelInputTransformType InputTransformType = DispatchPolicy::InputTransformType; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Transfer registers from regular warps to Accum warps + static constexpr uint32_t GenericRegisterRequirement = 152; + static constexpr uint32_t AccumRegisterRequirement = 200; + + // Pipeline and pipeline state types + using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; + using Load2TransformPipelineState = typename CollectiveMainloop::Load2TransformPipelineState; + + using Transform2MmaPipeline = typename CollectiveMainloop::Transform2MmaPipeline; + using Transform2MmaPipelineState = typename CollectiveMainloop::Transform2MmaPipelineState; + + using Mma2AccumPipeline = typename CollectiveMainloop::Mma2AccumPipeline; + using Mma2AccumPipelineState = typename CollectiveMainloop::Mma2AccumPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorMapStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + alignas(128) EpilogueTensorMapStorage epilogue; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + // NOTE: MMA must be on the 0th thread of the warp-group, so make sure pipeline leader is on MainloopLoad warp + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, + // Transformation starts at 256 thread alignment + Transformation = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t transformation = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + ProblemShape problem_shapes = args.problem_shape; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Mainloop + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Mainloop + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + return TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape.get_host_problem_shape(), + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info + ); +} + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + auto problem_shape = params.problem_shape; + + // Account for multiple epilogue and transformation warps + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::Transformation) ? WarpCategory::Epilogue + : WarpCategory::Transformation; + int thread_idx = int(threadIdx.x); + int thread_idx_in_warp = thread_idx % 32; + uint32_t lane_predicate = cute::elect_one_sync(); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + bool is_first_cta_in_cluster = (cta_rank_in_cluster == 0); + bool is_mma_leader_cta = (cta_rank_in_cluster % size<0>(TiledMma{}) == 0); + // Even if this variable is unused, shape_div still performs useful compile-time checks. + [[maybe_unused]] auto mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && (is_first_cta_in_cluster), // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::Transformation) // transformation + }; + + // MainloopLoad <--> Transformation Pipeline + typename Load2TransformPipeline::Params load2transform_pipeline_params; + if (warp_category == WarpCategory::MainloopLoad) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Transformation) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + } + load2transform_pipeline_params.is_leader = (thread_idx_in_warp == 0); + load2transform_pipeline_params.num_consumers = NumTransformationThreads; + load2transform_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + load2transform_pipeline_params.initializing_warp = 0; + Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.mainloop.load2transform_pipeline, + load2transform_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2TransformPipelineState load2transform_pipeline_consumer_state; + Load2TransformPipelineState load2transform_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Transformation <--> MMA pipeline + typename Transform2MmaPipeline::Params transform2mma_pipeline_params; + if (warp_category == WarpCategory::Transformation) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Consumer; + } + transform2mma_pipeline_params.consumer_arv_count = 1; + transform2mma_pipeline_params.producer_arv_count = size(AtomThrShapeMNK{}) * NumTransformationThreads; + transform2mma_pipeline_params.initializing_warp = 2; + Transform2MmaPipeline transform2mma_pipeline(shared_storage.pipelines.mainloop.transform2mma_pipeline, + transform2mma_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Transform2MmaPipelineState transform2mma_pipeline_consumer_state; + Transform2MmaPipelineState transform2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MMA <--> Accumulator pipeline + typename Mma2AccumPipeline::Params mma2accum_pipeline_params; + if (warp_category == WarpCategory::MMA) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Epilogue) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Consumer; + } + mma2accum_pipeline_params.producer_arv_count = 1; + mma2accum_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + mma2accum_pipeline_params.initializing_warp = 6; + Mma2AccumPipeline mma2accum_pipeline(shared_storage.pipelines.mainloop.mma2accum_pipeline, + mma2accum_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Mma2AccumPipelineState mma2accum_pipeline_consumer_state; + Mma2AccumPipelineState mma2accum_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = 1; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // CLC pipeline + // Operates Scheduling Warp <--> All Warps + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + + NumMMAThreads + NumTransformationThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + TmemAllocator tmem_allocator{}; + + // Sync allocation status between transform, MMA, and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumTransformationThreads + NumMMAThreads + NumEpilogueThreads, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + (is_epi_load_needed ? NumEpilogueLoadThreads : 0) + + NumTransformationThreads); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + load2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster); + transform2mma_pipeline.init_masks(cluster_shape); + mma2accum_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + + // Allocate accumulators + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + + // NOTE: we can assume the tmem buf starts at zero since we allocate all tmem in this kernel + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // Tile transform inputs now to get the k tile count + auto transform_inputs = collective_mainloop.transform_init(params.mainloop, problem_shape_MNKL, bulk_tmem, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(transform_inputs); + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + params.hw_info.sm_count, static_cast(cutlass::arch::SmId())); + Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to + // that args value being device-only. + bool did_batch_change = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + curr_batch, + lane_predicate + ); + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (i.e. it's aligned) + collective_mainloop.tensormaps_cp_fence_release(shared_storage.tensormaps.mainloop, input_tensormaps); + } + + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(Load2TransformPipeline::Stages, k_tile_count); + + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + collective_mainloop.tensormaps_fence_acquire(input_tensormaps); + } + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + if (lane_predicate) { + auto [load2transform_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [load2transform_pipeline_producer_state_next_, unused_] = collective_mainloop.load( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next_; + } + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + work_tile_info = next_work_tile_info; + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + if (lane_predicate) { + load2transform_pipeline.producer_tail(load2transform_pipeline_producer_state); + } + + } + + else if (is_participant.transformation) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + auto [load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( + load2transform_pipeline, + load2transform_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_producer_state, + bulk_tmem, + transform_inputs, + k_tile_iter, k_tile_count + ); + transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; + load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + } while (work_tile_info.is_valid()); + + transform2mma_pipeline.producer_tail(transform2mma_pipeline_producer_state); + } + + else if (is_participant.sched) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipeline_producer_state = scheduler.advance_to_next_work( + clc_pipeline, + clc_pipeline_producer_state + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipeline_producer_state); + } + } + + else if (is_participant.mma) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Allocate all tmem + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + auto mma_input_operands = collective_mainloop.mma_init(bulk_tmem, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (is_mma_leader_cta) { + auto [transform2mma_pipeline_consumer_state_next, mma2accum_pipeline_producer_state_next] = collective_mainloop.mma( + transform2mma_pipeline, + transform2mma_pipeline_consumer_state, + mma2accum_pipeline, + mma2accum_pipeline_producer_state, + bulk_tmem, + mma_input_operands, + k_tile_count + ); + // Advance the mm2accum pipe + transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state_next; + mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state_next; + } + } while (work_tile_info.is_valid()); + + // leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + mma2accum_pipeline.producer_tail(mma2accum_pipeline_producer_state); + } + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Signal to peer MMA that stage can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Tmem deallocation sequence + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + // Fetch a copy of tensormaps for the CTA from Params + auto epi_load_tensormap = get<0>(collective_epilogue.load_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, static_cast(cutlass::arch::SmId()))); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape, + curr_batch + ); + } + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + cute::make_tuple(epi_load_tensormap, did_batch_change) + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + auto accum_inputs = collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + bool do_tail_store = false; + auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); + // Fetch a copy of tensormaps for the CTA from Params + auto epi_store_tensormap = get<0>(collective_epilogue.store_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, static_cast(cutlass::arch::SmId()))); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape, + curr_batch + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + if constexpr (InputTransformType == cutlass::gemm::detail::KernelInputTransformType::FastF32) { + auto [mma2accum_pipeline_consumer_state_next,tTR_rGlobAcc] = collective_mainloop.accum( + accum_inputs, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + k_tile_count); + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_fence_acquire(epi_store_tensormap); + } + auto [load_state_next, store_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + tTR_rGlobAcc, + shared_storage.tensors.epilogue, + epi_store_tensormap, + get<0>(accum_inputs) // tiled_t2r + ); + + do_tail_store |= TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + // Advance the mm2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_consumer_state_next; + } + // Complex kernels use a collective epilogue + else { + mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); + + // Accumulators (real and imag) + Tensor accumulators = bulk_tmem(_,_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + // + // Epilogue and write to gD + // + // The tile scheduler and current work are passed into the collective epilogue to + // support fixup operations needed by split-/stream-K. These operations are pushed + // to the collective layer so that they can reuse the TMEM -> RF copy performed + // at the collective layer. + auto [mma2accum_pipeline_state_next] = collective_epilogue( + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + accumulators, + shared_storage.tensors.epilogue + ); + // Advance the mm2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp new file mode 100644 index 00000000..4e1d2930 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp @@ -0,0 +1,1065 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool IsComplex = DispatchPolicy::InputTransformType == cutlass::gemm::detail::KernelInputTransformType::InterleavedComplexTF32; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + // TileID scheduler + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveMainloop::NumAccumThreads; // 4 warps + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumTransformationThreads = CollectiveMainloop::NumTransformationThreads; // 4 warps + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + + NumEpilogueThreads + NumTransformationThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr cutlass::gemm::detail::KernelInputTransformType InputTransformType = DispatchPolicy::InputTransformType; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Transfer registers from regular warps to Accum warps + static constexpr uint32_t GenericRegisterRequirement = 152; + static constexpr uint32_t AccumRegisterRequirement = 200; + + // Pipeline and pipeline state types + using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; + using Load2TransformPipelineState = typename CollectiveMainloop::Load2TransformPipelineState; + + using Transform2MmaPipeline = typename CollectiveMainloop::Transform2MmaPipeline; + using Transform2MmaPipelineState = typename CollectiveMainloop::Transform2MmaPipelineState; + + using Mma2AccumPipeline = typename CollectiveMainloop::Mma2AccumPipeline; + using Mma2AccumPipelineState = typename CollectiveMainloop::Mma2AccumPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, + // Transformation starts at 256 thread alignment + Transformation = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t transformation = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + auto blk_shape = CtaShape_MNK{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for multiple epilogue and transformation warps + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::Transformation) ? WarpCategory::Epilogue + : WarpCategory::Transformation; + int thread_idx = int(threadIdx.x); + int thread_idx_in_warp = thread_idx % 32; + uint32_t lane_predicate = cute::elect_one_sync(); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + bool is_first_cta_in_cluster = (cta_rank_in_cluster == 0); + bool is_mma_leader_cta = (cta_rank_in_cluster % size<0>(TiledMma{}) == 0); + // Even if this variable is unused, shape_div still performs useful compile-time checks. + [[maybe_unused]] auto mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && (is_first_cta_in_cluster), // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::Transformation) // transformation + }; + + // MainloopLoad <--> Transformation Pipeline + typename Load2TransformPipeline::Params load2transform_pipeline_params; + if (warp_category == WarpCategory::MainloopLoad) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Transformation) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + } + load2transform_pipeline_params.is_leader = (thread_idx_in_warp == 0); + load2transform_pipeline_params.num_consumers = NumTransformationThreads; + load2transform_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + load2transform_pipeline_params.initializing_warp = 0; + Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.mainloop.load2transform_pipeline, + load2transform_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2TransformPipelineState load2transform_pipeline_consumer_state; + Load2TransformPipelineState load2transform_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Transformation <--> MMA pipeline + typename Transform2MmaPipeline::Params transform2mma_pipeline_params; + if (warp_category == WarpCategory::Transformation) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Consumer; + } + transform2mma_pipeline_params.consumer_arv_count = 1; + transform2mma_pipeline_params.producer_arv_count = size(AtomThrShapeMNK{}) * NumTransformationThreads; + transform2mma_pipeline_params.initializing_warp = 2; + Transform2MmaPipeline transform2mma_pipeline(shared_storage.pipelines.mainloop.transform2mma_pipeline, + transform2mma_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Transform2MmaPipelineState transform2mma_pipeline_consumer_state; + Transform2MmaPipelineState transform2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MMA <--> Accumulator pipeline + typename Mma2AccumPipeline::Params mma2accum_pipeline_params; + if (warp_category == WarpCategory::MMA) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Epilogue) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Consumer; + } + mma2accum_pipeline_params.producer_arv_count = 1; + mma2accum_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + mma2accum_pipeline_params.initializing_warp = 6; + Mma2AccumPipeline mma2accum_pipeline(shared_storage.pipelines.mainloop.mma2accum_pipeline, + mma2accum_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Mma2AccumPipelineState mma2accum_pipeline_consumer_state; + Mma2AccumPipelineState mma2accum_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = 1; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // CLC pipeline + // Operates Scheduling Warp <--> All Warps + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + + NumMMAThreads + NumTransformationThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between transform, MMA, and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumTransformationThreads + NumMMAThreads + NumEpilogueThreads, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + (is_epi_load_needed ? NumEpilogueLoadThreads : 0) + + NumTransformationThreads); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + load2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster); + transform2mma_pipeline.init_masks(cluster_shape); + mma2accum_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // Allocate accumulators + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // Tile transform inputs now to get the k tile count + auto transform_inputs = collective_mainloop.transform_init(params.mainloop, problem_shape_MNKL, bulk_tmem, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(transform_inputs); + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(Load2TransformPipeline::Stages, k_tile_count); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + if (lane_predicate) { + auto [load2transform_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [load2transform_pipeline_producer_state_next_, unused_] = collective_mainloop.load( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next_; + } + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + if (lane_predicate) { + load2transform_pipeline.producer_tail(load2transform_pipeline_producer_state); + } + + } + + else if (is_participant.sched) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipeline_producer_state = scheduler.advance_to_next_work( + clc_pipeline, + clc_pipeline_producer_state + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipeline_producer_state); + } + } + + else if (is_participant.transformation) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + auto [load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( + load2transform_pipeline, + load2transform_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_producer_state, + bulk_tmem, + transform_inputs, + k_tile_iter, k_tile_count + ); + transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; + load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + } while (work_tile_info.is_valid()); + + transform2mma_pipeline.producer_tail(transform2mma_pipeline_producer_state); + } + + else if (is_participant.mma) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + + auto mma_input_operands = collective_mainloop.mma_init(bulk_tmem, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (is_mma_leader_cta) { + auto [transform2mma_pipeline_consumer_state_next, mma2accum_pipeline_producer_state_next] = collective_mainloop.mma( + transform2mma_pipeline, + transform2mma_pipeline_consumer_state, + mma2accum_pipeline, + mma2accum_pipeline_producer_state, + bulk_tmem, + mma_input_operands, + k_tile_count + ); + // Advance the mm2accum pipe + transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state_next; + mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state_next; + } + } while (work_tile_info.is_valid()); + + // leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + mma2accum_pipeline.producer_tail(mma2accum_pipeline_producer_state); + } + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + auto accum_inputs = collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + if constexpr (InputTransformType == cutlass::gemm::detail::KernelInputTransformType::FastF32) { + auto [mma2accum_pipeline_consumer_state_next,tTR_rGlobAcc] = collective_mainloop.accum( + accum_inputs, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + k_tile_count); + + mma2accum_pipeline_consumer_state_next = scheduler.template fixup( + TiledMma{}, + work_tile_info, + tTR_rGlobAcc, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state_next, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + tTR_rGlobAcc, + shared_storage.tensors.epilogue, + get<0>(accum_inputs) // tiled_t2r + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + do_tail_store = true; + } + + // Advance the mm2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_consumer_state_next; + } + // Complex kernels use a collective epilogue + else { + mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); + + // Accumulators (real and imag) + Tensor accumulators = bulk_tmem(_,_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + mma2accum_pipeline_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [mma2accum_pipeline_state_next] = collective_epilogue( + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + accumulators, + shared_storage.tensors.epilogue + ); + // Advance the mm2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + } + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index b65c45c2..7c375747 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -120,7 +120,7 @@ public: typename detail::TileSchedulerSelector< GroupScheduler, ArchTag, TileShape, ClusterShape, - 2, // Default unused parameter - SchedulerPipelineStageCoun + 2, // Default unused parameter - SchedulerPipelineStageCount ProblemShape>::Scheduler, typename detail::TileSchedulerSelector< void, ArchTag, TileShape, ClusterShape>::Scheduler>; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 6311c601..5bd5196f 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -120,7 +120,7 @@ public: typename detail::TileSchedulerSelector< GroupScheduler, ArchTag, TileShape, ClusterShape, - 2, // Default unused parameter - SchedulerPipelineStageCoun + 2, // Default unused parameter - SchedulerPipelineStageCount ProblemShape>::Scheduler, typename detail::TileSchedulerSelector< void, ArchTag, TileShape, ClusterShape>::Scheduler>; diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index fde8eb07..33c7d8ec 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -1095,6 +1095,34 @@ struct NumericArrayConverter <= Array, round to nearest with min/max saturation +template <> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest_satfinite; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned d; + + asm("cvt.rn.satfinite.bf16x2.f32 %0, %1, %2;\n" : "=r"(d) : "f"(source[1]), "f"(source[0]) ); + + return reinterpret_cast(d); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + + /// Partial specialization for Array <= Array template < int N, @@ -2382,7 +2410,6 @@ struct NumericArrayConverterPacked4Element { }; - ///////////////////////////////////////////////////////////////////////////////////////////////// // // Partial specializations for Array <=> Array @@ -3579,7 +3606,6 @@ template < > struct NumericArrayConverter : public PackedNumericArrayConverter {}; - /// Partial specialization for Array <= Array template < typename T, diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp index e5ac47a8..ef2e80bf 100644 --- a/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -275,6 +275,189 @@ private: } }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA (producer) Transform (consumer) Async Pipeline +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template < + int Stages_, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaTransformAsync { +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; +private: + using Impl = PipelineTmaAsync; +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + // Constructor + template + CUTLASS_DEVICE + PipelineTmaTransformAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , full_barrier_ptr_(&storage.full_barrier_[0]) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + // Helper function to initialize barriers + template + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static_assert(IsDynamicCluster or ((cute::size<0>(cluster_shape) % cute::size<0>(atom_thr_shape) == 0) && + (cute::size<1>(cluster_shape) % cute::size<1>(atom_thr_shape) == 0))); + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + // Logic to optimally schedule Empty Arrives + // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) + int warp_idx = canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; + auto cluster_size = cute::size(cluster_shape); + + // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) + if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warpgroup(thread_idx % NumThreadsPerWarpGroup, warp_idx); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else if (params_.num_consumers == 32) { + auto [is_signaling_thread, dst_blockid] = detail::spread_arrivals_to_warp(thread_idx % 32); + is_signaling_thread_ = is_signaling_thread; + dst_blockid_ = dst_blockid; + } + else { + is_signaling_thread_ = 0; + #ifndef NDEBUG + asm volatile ("brkpt;\n" ::); + #endif + } + + // STEP 2: Find if this dst block-id needs an arrival for this problem + is_signaling_thread_ &= dst_blockid_ < cluster_size; + is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); + } + } + + template + CUTLASS_DEVICE + bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { + return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) || + ( + ((dst_block_id / cute::size<0>(cluster_shape)) == block_id.y) + // If we are in the same cluster column and using 2CTA MMA, only odd or only even CTAs sync with each other + && ((dst_block_id % cute::size<0>(cluster_shape)) % cute::size<0>(AtomThrShape_MNK{}) == + block_id.x % cute::size<0>(AtomThrShape_MNK{})) + )); + } + + //////////////////// + // Producer APIs + //////////////////// + CUTLASS_DEVICE + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + return impl_.producer_try_acquire(state, skip_wait); + } + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_commit(PipelineState state, uint32_t bytes) { + impl_.producer_commit(state, bytes); + } + + // Prevents early exit of producer blocks in Cluster. + // This should be called once before kernel exits. + CUTLASS_DEVICE + void producer_tail(PipelineState state) { + impl_.producer_tail(state); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + //////////////////// + // Consumer APIs + //////////////////// + CUTLASS_DEVICE + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_try_wait(state, skip_wait); + } + + CUTLASS_DEVICE + ConsumerToken consumer_test_wait(PipelineState state, uint32_t skip_wait = false) { + return impl_.consumer_test_wait(state, skip_wait); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state) { + impl_.consumer_wait(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state, uint32_t skip = false) { + detail::pipeline_check_is_consumer(params_.role); + empty_barrier_ptr_[state.index()].arrive(dst_blockid_, is_signaling_thread_ & (!skip)); + } + +private: + Impl impl_; + uint32_t dst_blockid_ = 0; + uint32_t is_signaling_thread_ = 0; + FullBarrier *full_barrier_ptr_ = nullptr; + EmptyBarrier *empty_barrier_ptr_ = nullptr; + Params params_; +}; + /////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -391,7 +574,6 @@ public: } } - // !!!!!! I DONT LIKE THIS MCAST BASED CONSTRUCTOR SPECIALIZATION. THIS VARIABLE NEVER CHANGES AT RUNTIME. template CUTLASS_DEVICE PipelineTmaUmmaAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 65b77904..68bdb26e 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -71,7 +71,7 @@ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { if (a == b) { return true; } - else if (a == zero || b == zero || diff < nonzero_floor) { + else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) { return diff < epsilon * nonzero_floor; } diff --git a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp index 38a39740..59f4ba1f 100644 --- a/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +++ b/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -285,7 +285,7 @@ private: uint8_t storage_ = 0b0000; }; - using MetadataOneChunk = cute::conditional_t; diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 6383f979..736344b4 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -24,6 +24,17 @@ $ make cutlass_profiler -j Enabling the unity build places multiple kernel instances in one compilation unit, thereby reducing size of the compiled binary and avoiding linker limitations on some platforms. +The CUTLASS Profiler sources are stored in: + +```bash +tools/ + profiler/ +``` + +# Emitting kernels via `emit_kernel_listing.py` + +We provide a Python script `emit_kernel_listing.py` that allows a user to selectively test a subset of profiler-based kernels stamped out in `generator.py`. A unique benefit to generate kernels and test via this script is that it can feed a series of runtime arguments, such as different `M`/`N`/`K` and `alpha`/`beta`, to each kernel, instead of relying on a single default value. It also properly generates runtime datatype and cluster shapes for certain kernels to help reduce the generated kernel count and accordingly the total compilation time. An interested user may refer to [emit_kernel_listing.py](../../python/cutlass_library/emit_kernel_listing.py) for details. To enable this new feature, a user should add `-DCUTLASS_BUILD_FOR_PROFILER_REGRESSIONS=ON` when building CUTLASS profiler. + ### Instantiating more kernels with Hopper With Hopper (SM90), you will need to use an additional flag, `CUTLASS_LIBRARY_INSTANTIATION_LEVEL`, in order to instantiate all possible combinations, @@ -84,12 +95,30 @@ An instantiation level `500`, which is padded to `0500`, thus indicates: - **Cluster Sizes**: At level 5, allowing for clusters with 1, 2, 4, 8, or 16 CTAs. - **Schedule Pruning**: At level 0, where pruning is applied according to the existing `generator.py` behavior. -The CUTLASS Profiler sources are stored in: +### Mixed input data type kernels for Hopper -```bash -tools/ - profiler/ -``` +With Hopper (SM90), the kernel generator will generate the following combinations of mixed input data types ("mixed dtype"): + +| dtype(A) | dtype(B) | +| -------- | ---------- | +| e4m3 | f16, bf16 | +| e5m2 | f16, bf16 | +| int8 | f16, bf16 | +| uint8 | f16, bf16 | +| int4 | f16, bf16 | +| int4 | e4m3, e5m2 | +| uint4 | f16, bf16 | +| int2 | f16, bf16 | +| uint2 | f16, bf16 | + +For each mixed dtype kernel, the kernel generator will generate combinations of three different running modes: +* Convert-only +* Scale-only +* Scale-with-zero-point-shifting + +For {4-bits-dtype, 8-bits-dtype} x 16-bits-dtype, the kernel generator will further generate kernels using shuffled layouts for the narrow data type matrix, which may have a better performance compared to its non-shuffle counter parts. + +### CUTLASS Profiler usage The CUTLASS Profiler usage statement may be obtained by executing `cutlass_profiler --help` and appears as follows. ```bash @@ -175,6 +204,8 @@ Profiling: --warmup-iterations= Number of iterations to execute each kernel prior to profiling (default: 10). + --use-cuda-graphs= If true, kernels are launched in a CUDA graph. Useful when the kernel launch time is a bottleneck. + --sleep-duration= Number of ms to sleep between profiling periods (ms). --profiling-enabled= If true, profiling is actually conducted. @@ -289,6 +320,8 @@ GEMM [enum] --raster_order={heuristic|H|along_m|M|along_n|N} If supported by kernel, sets the tile raster direction [int] --swizzle_size={1,2,4,8} If supported by kernel, sets the 2D tile swizzle extent (In Hopper, other values will be rounded down to the nearest supported value) [int] --use_pdl,--use-pdl Use PDL (true, false) + [enum] --runtime_input_datatype_a Runtime data type for A matrix, narrow-precision only (e4m3, e5m2, e3m2, e2m3, e2m1) + [enum] --runtime_input_datatype_b Runtime data type for B matrix, narrow-precision only (e4m3, e5m2, e3m2, e2m3, e2m1) Examples: @@ -329,7 +362,11 @@ Profile when execution is performed on device 0 and the C tensor is located on a The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`. -CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information. +In addition to encoded data types, CUTLASS profiler allows non-encoded generic data types, namely `f8`, `f6`, and `f4`, with corresponding encoding specified through GEMM input argument: `--runtime_input_datatype_a` and `--runtime_input_datatype_b`. Currently, six encoding schemes are supported: `e4m3`, `e5m2`, `e3m2`, `e2m3`, and `e2m1`. + +Cluster shapes can be statically set to `Shape;` and specified via runtime arguments: `cluster_m`, `cluster_n` and `cluster_k` in CUTLASS profiler. One may refer to our CUTLASS Example [73_blackwell_gemm_flexible_cluster](../../examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for more details of the this feature. + +CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information. CUDA graphs can also be used (`--use-cuda-graphs`) with PDL to ensure that smaller kernels are enqueued back-to-back on a stream. ## Example CUDA Core GEMM Operation @@ -444,7 +481,7 @@ To best illustrate this naming convention, we will walk through the meaning of e in a GEMM kernel used by the profiler: ``` -cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f32_128x128x64_2x1x1_0_ntn_align8 +cutlass3x_sm90_tensorop_s64x128x16gemm_f16_f16_f32_f16_f32_{optional-mixed-dtype-config}_128x128x64_2x1x1_0_ntn_align8 ``` The components within this name are as follows: @@ -457,6 +494,7 @@ The components within this name are as follows: (as opposed to `h`, which indicates half precision) * `64x128x16gemm`: indicates that the shape of the Tensor Core instruction being used (MxNxK) is 64x128x16 * `f16_f16_f32_f16_f16`: indicates that the data types for operands A, B, Accumulator, C and D (in that order). +* `optional-mixed-dtype-config`: optional, will be empty if this is not a mixed dtype kernel. For mixed dtype kernels, it contains `_cvt`, `_scl`, `_sclzr`, respectively, for convert-only, scale-only, scale-with-zero-point running modes. It further contains `_shfl` if the kernel uses a shuffled layout for the narrow data type input matrix. * `128x128x64`: indicates that the thread block shape used in the GEMM (MxNxK) is 128x128x64 * `2x1x1`: indicates that the cluster shape being used is 2x1x1 * `0`: indicates that the kernel uses the CollectiveBuilder's automatic stage calculation to determine the diff --git a/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg b/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg old mode 100755 new mode 100644 index 0f1441e5..aaa386ac --- a/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg +++ b/media/images/cutlass-3.8-blackwell-gemm-peak-performance.svg @@ -1 +1 @@ -0102030405060708090100MXFP4MXFP8S8F8F16BF16TF32% throughput of theoretical peak @ powerCUTLASS 3.8 + CUDA 12.8 Blackwell SM100 GEMM Performance16384x17920x16384 shape matmul on uniform random ints range [-4,4] \ No newline at end of file +0102030405060708090100NVFP4MXFP4MXFP8S8F8F16BF16TF32% throughput of theoretical peak @ powerCUTLASS 3.8 + CUDA 12.8 Blackwell SM100 GEMM Performance16384x17920x16384 shape matmul on uniform random ints range [-4,4] diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py new file mode 100755 index 00000000..96733d60 --- /dev/null +++ b/python/cutlass_library/emit_kernel_listing.py @@ -0,0 +1,834 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# +# +# \brief Generates the CUTLASS kernel listing with kernel filtering +# + +# + +############################################################################### +# Example usage: +# generator.py --operations all --generator-target kernel_listing \ +# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports +############################################################################### + +import collections +import csv +import json +import math +import os + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +audit_csv_fields = [ + "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD", + "Layout_A", "Layout_B", "Layout_C", "Layout_D", + "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D", + "1SM/2SM", + "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types", + "Test Counts" +] + +audit_csv_runtime_fields = [ + "KerneIndex", "KernelName", + "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K", + "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K", + "M", "N", "K", "L", "Alpha_val", "Beta_val", + "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled" +] + +def hash_cutlass_string(input_string): + # Regex pattern to match instruction shape + instruction_shape_pattern = r"[a-zA-Z]\d+x\d+x\d+" # Matches '_s128x128x64', '_h64x128x16', etc. + mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + + # Remove instruction shape (e.g., '_s128x128x64', '_h64x128x16') + output = re.sub(instruction_shape_pattern, "", input_string) + + # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + output = re.sub(mma_cluster_shape_pattern, "", output) + + return output + +def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): + # Define a dictionary mapping the detected types to runtime values + datatype_map = { + '_f4_': '_' + runtime_datatype_a + '_', + '_f6_': '_' + runtime_datatype_b + '_', + '_f8_': '_' + runtime_datatype_a + '_', + } + + # Use regex to identify and replace _f4_, _f6_, or _f8_ in the kernel name + def substitute(match): + datatype = match.group(0) # This is the matched "_f4_", "_f6_", or "_f8_" + return datatype_map.get(datatype, datatype) # Replace or leave as is + + # Regex to find "_f4_", "_f6_", or "_f8_" in the hashed_kernel_name + updated_kernel_name = re.sub(r'_f4_|_f6_|_f8_', substitute, hashed_kernel_name) + + return updated_kernel_name + +# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. +def get_kernel_features(operation, kernel_name, + dynamic_datatype, runtime_input_datatype): + numcta_inst = "2sm" if "2sm" in kernel_name else "1sm" + math_inst = operation.tile_description.math_instruction + + if dynamic_datatype: + dtype_name_A = runtime_input_datatype[0] + dtype_name_B = runtime_input_datatype[1] + else: + dtype_name_A = DataTypeNames[operation.A.element] + dtype_name_B = DataTypeNames[operation.B.element] + + layout_name_A = ShortLayoutTypeNames[operation.A.layout] + layout_name_B = ShortLayoutTypeNames[operation.B.layout] + layout_name_C = ShortLayoutTypeNames[operation.C.layout] + layout_name_D = ShortLayoutTypeNames[operation.D.layout] + + scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void + scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void) + audit_vals = [ + "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM", + kernel_name, + dtype_name_A, + dtype_name_B, + DataTypeNames[operation.C.element], + DataTypeNames[operation.tile_description.math_instruction.element_accumulator], + DataTypeNames[operation.element_epilogue], + DataTypeNames[operation.D.element], + DataTypeNames[scale_factor_D_type], + DataTypeNames[scale_factor_A_type], + layout_name_A, + layout_name_B, + layout_name_C, + layout_name_D, + str(operation.A.alignment), + str(operation.B.alignment), + str(operation.C.alignment), + str(operation.D.alignment), + numcta_inst, + "Y" if 'stream_k' in kernel_name else "N", + ] + return audit_vals + +# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta. +def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster): + math_inst = operation.tile_description.math_instruction + audit_vals = [ + str(math_inst.instruction_shape[0]), + str(math_inst.instruction_shape[1]), + str(math_inst.instruction_shape[2]), + str(operation.tile_description.threadblock_shape[0]), + str(operation.tile_description.threadblock_shape[1]), + str(operation.tile_description.threadblock_shape[2]), + str(operation.tile_description.cluster_shape[0]), + str(operation.tile_description.cluster_shape[1]), + str(operation.tile_description.cluster_shape[2]), + str(cluster_shape[0]), + str(cluster_shape[1]), + str(cluster_shape[2]), + str(fallback_cluster_shape[0]), + str(fallback_cluster_shape[1]), + str(fallback_cluster_shape[2]), + str(problem_shape[0]), + str(problem_shape[1]), + str(problem_shape[2]), + str(problem_shape[3]), + str(alpha), + str(beta), + "Y" if dynamic_datatype else "N", + "Y" if dynamic_cluster else "N", + ] + return audit_vals + + +def _getSubOperationType(kernel): + + if kernel.operation_kind == OperationKind.Gemm: + return GemmKindNames[kernel.gemm_kind] + elif kernel.operation_kind == OperationKind.Conv2d: + return "conv_" + ConvKindNames[kernel.conv_kind] + elif kernel.operation_kind == OperationKind.Syrk: + return "syrk_" + SyrkKindNames[kernel.syrk_kind] + elif kernel.operation_kind == OperationKind.Trmm: + return "trmm_" + TrmmKindNames[kernel.trmm_kind] + elif kernel.operation_kind == OperationKind.Symm: + return "symm_" + SymmKindNames[kernel.symm_kind] + else: + raise Exception("Unsupported kernel type") + +def _get_inst_shape(math_instruction): + return "".join(str(x) for x in math_instruction.instruction_shape) + +def _is_simt_inst(math_instruction): + return _get_inst_shape(math_instruction) in ["111","114"] + +def _getInstType(input_precision, accumulate_precision, math_instruction): + + # inst_shape + inst_shape = _get_inst_shape(math_instruction) + + # input precision + if input_precision == "fp32" and inst_shape != "111": + inp = "tf32" + else: + inp = input_precision + + # Handle SIMT op types first + if _is_simt_inst(math_instruction): + + simt_input_precision_to_inst = { + "fp32": "FFMA", + "fp64": "DFMA", + "fp16": "HFMA", + "int8": "IDP4A", + } + inst = simt_input_precision_to_inst[input_precision] + + else: # Tensor op instructions + + if accumulate_precision == "cf64": + fp64_acc_map = { + MathOperation.multiply_add_complex_gaussian : "gz", + MathOperation.multiply_add_complex : "z", + } + acc = fp64_acc_map[math_instruction.math_operation] + else: + tensor_op_acc_map = { + "fp32" : "s", + "cf32" : "s", + "fp16" : "h", + "int32": "i", + "fp64" : "d", + } + acc = tensor_op_acc_map[accumulate_precision] + + inst = "{}{}{}".format(acc, inst_shape, inp) + + return inst +# TODO: Computes FLOps/Bytes for GEMM - revisit for conv +def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0): + + # TODO: adjust for sparsity + gmem_bytes = ( + (DataTypeSize[operation.A.element] * m // 8) * k + + (DataTypeSize[operation.B.element] * n // 8) * k + + (DataTypeSize[operation.C.element] * m // 8) * n + ) + + # TODO: complex-valued support + flops = 2 * (m * n * k) + + if bool(beta): + gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n + flops += 2 * m * n + + gmem_bytes *= batch_count + flops *= batch_count + + return flops / gmem_bytes + +def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode + ): + profiler_reference_computing = "--verification-providers=device --providers=cutlass" + # beta values for L0 and L1 + # TODO: randomize beta values for wider coverage + beta_values = [0.5] + + is_supported_arch = (arch in ["100a"]) + + is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch + + if (mode == "functional_L0") and is_supported_arch: + problem_waves = [0.5, 1.25, 2.5] + + # + # Dense Gemm + # + + sm100_mma_data_type_general = [ + 'x16gemm_f16_f16_f16_f16_f16', + 'x16gemm_f16_f16_f16_void_f16', + 'x16gemm_f16_f16_f32_f16_f16', + 'x8tf32gemm_f32_f32_f32_f32_f32', + 'x16bf16gemm_f32_f32_f32_f32_f32', + ] + + sm100_mma_data_type_runtime_dtype = [ + 'x32gemm_f4_f4_f32_f32_f32', + 'x32gemm_f6_f6_f32_f32_f32', + 'x32gemm_f8_f8_f32_f32_f32', + ] + + sm100_mma_data_type_mergeable = [ + 'x32gemm_e4m3_e4m3_f32_f32_f32',# mask out one instance for verification + 'x32gemm_e2m1_e2m1_f32_f32_f32', + 'x32gemm_e3m2_e3m2_f32_f32_f32', + ] + + sm100_mma_cluster_size = [ + '8x1x1', + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + # Restrict to two layouts to reduce L0 build and test time. + sm100_mma_layouts = [ + 'tnt', + 'ntn' + ] + + sm100_mma_instruction_shape = [ + # [0] .1CTA, General + ['64x128', '128x128', '128x256'], + # [1] .2CTA, General + ['128x128', '256x128', '256x256'], + ] + + # regex list must be in kernel procedural name order + mergeable_sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + mergeable_sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_mergeable, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + # + # Block Scale Gemm + # + + block_scaled_data_type_base = [ + # runtime datatypes + 'x32gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'x32gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', + 'x64gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + 'x32gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', + ] + + block_scaled_data_type_mergeable = [ + 'x32gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'x32gemm.*ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', + 'x64gemm.*ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + 'x32gemm.*ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', + ] + + block_scaled_data_type = block_scaled_data_type_base + block_scaled_data_type_mergeable + + block_scaled_cluster_size = [ + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + block_scaled_layouts = ['tnt'] + block_scaled_instruction_shape = [ + # .1CTA + ['128x128', '128x192', '128x256'], + # .2CTA + ['256x128', '256x192', '256x256'], + ] + # regex list must be in kernel procedural name order + mergeable_block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + mergeable_block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type_mergeable, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + if arch == "100a": + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + else: + error_message = "unsupported arch, only support sm100a" + raise Exception(error_message) + + # Statically encoded kernels are still added to generated_kernels + # but are filtered out from the testing commands to reduce test duration. + # The mergeable_kernel_filter specifies the kernels that are already covered + # by the runtime datatype tests so that we safely mark them off + # without changing the test coverage. + mergeable_kernel_filter = f"({mergeable_sm100_mma_filter_regex_1sm})|" \ + f"({mergeable_sm100_mma_filter_regex_2sm})|" \ + f"({mergeable_block_scaled_filter_regex_1sm})|" \ + f"({mergeable_block_scaled_filter_regex_2sm})" + elif mode == "functional_L1": + + sm100_mma_cluster_size = [ + '0x0x1' # dynamic cluster + ] + # Restrict to two layouts to reduce L1 build and test time. + sm100_mma_layouts = ['tnt', 'ntn'] + sm100_mma_instruction_shape = [ + # .1CTA + ['64x128', '128x128', '128x256'], + # .2CTA + ['128x128', '256x128', '256x256'] + ] + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[0], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_instruction_shape[1], sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', + 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', + ] + + block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1'] + block_scaled_layouts = ['tnt'] + block_scaled_instruction_shape = [ + # .1CTA + ['128x128', '128x192', '128x256'], + # .2CTA + ['256x128', '256x192', '256x256'], + ] + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[0], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_instruction_shape[1], block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})|" + # CTA tiles for super MMA - only run one tile size to reduce build/test times + supermma_kernel_cta_tiles = [ + # h1688, s1688, i16832, i8816 + [ '256x128' ], + # d884, c1688, + [ '128x128' ], + # c1688, z884 + [ '128x64' ], + # gz884 + [ '64x64' ] + ] + + # super MMA instruction shapes, planar complex type excluded as they are not required + supermma_instruction_shapes = [ + [ 'h1688gemm_(?!planar_complex)', + 's1688gemm_f16', + 's1688gemm_bf16', + 's1688gemm_tf32', + 'i16832gemm', + 'i8816gemm' ], + [ 'd884gemm', 'c1688tf32gemm' ] , + [ 'c1688gemm', + 'z884gemm' ], + [ 'gz884gemm'] + ] + + # It's not pretty, but not sure why different instructions support different tile sizes. + filter_regex_supermma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[0], supermma_kernel_cta_tiles[0]]]) + ").*" + filter_regex_supermma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[1], supermma_kernel_cta_tiles[1]]]) + ").*" + filter_regex_supermma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[2], supermma_kernel_cta_tiles[2]]]) + ").*" + filter_regex_supermma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [supermma_instruction_shapes[3], supermma_kernel_cta_tiles[3]]]) + ").*" + + filter_regex_supermma = f"({filter_regex_supermma_0})|({filter_regex_supermma_1})|({filter_regex_supermma_2})|({filter_regex_supermma_3})" + + problem_waves = [0.5, 1.25, 2.5] + + kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_supermma})" + else: + raise ValueError() + + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + + audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv") + + audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") + + if is_runtime_datatype_enabled: + mergeable_kernel_filter_re = re.compile(mergeable_kernel_filter) + kernel_filter_re = re.compile(kernel_filter) + testcase_counter = 0 + kernels_emitted = 0 + kernels_total = 0 + + perf_json_list = [] + kernel_name_set = set() + + testlist_csv_fields = ["testcase", "metadata"] + testlist_csv_rows = [] + auditlist_csv_map = {} + auditlist_csv_params_map = {} + + kernel_features = {} + + for cc in manifest.operations[OperationKind.Gemm].keys(): + for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items(): + assert(len(operation_l) == 1) + kernels_total += 1 + if len(kernel_filter_re.findall(kernel_name)) == 0: + continue + # Only test f16 I/O void C kernels in void C kernel set + # Exception: Use void C kernels for more accurate perf testing + if '_void_' in kernel_name and 'perf_' not in mode: + if 'f16_f16_f16_void_f16' not in kernel_name : + continue + + # Filter out the statically encoded tests which are + # covered by runtime datatype tests to avoid repetition. + if is_runtime_datatype_enabled and len(mergeable_kernel_filter_re.findall(kernel_name)) != 0: + continue + + + kernels_emitted += 1 + kernel_name_set.add(kernel_name) + hashed_kernel_name = hash_cutlass_string(kernel_name) + operation = operation_l[0] + + dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0 + or operation.tile_description.cluster_shape[1] == 0) + + dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name + + runtime_input_datatypes = [None] + + if dynamic_datatype: + if "f4_f4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "f4_f6" in kernel_name: + runtime_input_datatypes = [['e2m1','e3m2']] + elif "f4_f8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "f6_f4" in kernel_name: + runtime_input_datatypes = [['e3m2','e2m1']] + elif "f6_f6" in kernel_name: + runtime_input_datatypes = [['e3m2','e3m2']] + elif "f6_f8" in kernel_name: + runtime_input_datatypes = [['e3m2','e4m3']] + + elif "f8_f4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "f8_f6" in kernel_name: + runtime_input_datatypes = [['e4m3','e3m2']] + elif "f8_f8" in kernel_name: + runtime_input_datatypes = [ + # mask out those not covered in statically encoded test cases + # ['e5m2','e4m3'], + # ['e4m3','e5m2'], + ['e4m3','e4m3'] + ] + + # block scaled kernels + elif "ue8m0xf4_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue4m3xf4_ue4m3xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue8m0xf4_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m3']] + elif "ue8m0xf4_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "ue8m0xf6_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m1']] + elif "ue8m0xf6_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m3']] + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "ue8m0xf8_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m3']] + elif "ue8m0xf8_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e4m3','e4m3']] + + if dynamic_cluster: + if mode == "functional_L0": + runtime_cluster_shapes = [[1,1,1], [2,1,1], [2,2,1], [4,1,1], [4,4,1]] + else: + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] + cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape + else: + runtime_cluster_shapes = [operation.tile_description.cluster_shape] + cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0]) + cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1]) + cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2]) + + alignment_a = operation.A.alignment + alignment_b = operation.B.alignment + alignment_c = operation.C.alignment + alignment_ab_max = max(alignment_a, alignment_b) + + layout3x = operation.layout_name_3x() + data_types = operation.datatype_name_3x() + + ctas_per_mma_instruction = 1 + if '_2sm' in kernel_name: + ctas_per_mma_instruction = 2 + valid_cluster_shapes = [] + + # Remove any cluster shapes that have cluster_m that is not divisible by 2 + for cs in runtime_cluster_shapes: + if cs[0] % 2 == 0: + valid_cluster_shapes.append(cs) + runtime_cluster_shapes = valid_cluster_shapes + + kernel_problem_waves = problem_waves + if mode == "functional_L0" or mode == "functional_L1": + # for functional testing, we want to perturb just a little from even shapes + # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not + # -16 ensures that we are TMA aligned even for FP8/Int8 + min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max + max_k = (cta_tile_shape_k*8) - alignment_ab_max + problem_shapes_k = [min_k, max_k] + sm_count = 16 + # Larger k and less than half wave trigger streamk +separate reduction case to be generated + if 'stream_k' in kernel_name: + problem_shapes_k = [max_k, cta_tile_shape_k*32] + kernel_problem_waves = [0.125, 1.25, 2.5] + else: + raise ValueError + + if "void" in kernel_name: + beta_values = [0] + + alignment_shift_m = max(alignment_c, alignment_a) + alignment_shift_n = max(alignment_c, alignment_b) + + is_first_line = True + for index_waves, waves in enumerate(kernel_problem_waves): + for index_k, k in enumerate(problem_shapes_k): + for beta in beta_values: + for cluster_shape in runtime_cluster_shapes: + for runtime_input_datatype in runtime_input_datatypes: + grid_size = waves * sm_count + cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape) + if cluster_shape_m >= cluster_shape_n: + grid_m = cluster_shape_m + grid_n = grid_size / grid_m + grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1) + else: + grid_n = cluster_shape_n + grid_m = grid_size / grid_n + grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1) + + verification_required = False + if mode == "functional_L0" or mode == "functional_L1": + if '_void_' not in kernel_name: + verification_required = True + + m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max) + n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max) + k = int(k) + + # For functional testing, we want to perturb just a little from even shapes. + # Only do this if the perturbation does not cause one of the dimensions of the + # problem size to go to zero. This can occur for blockscaling kernels for which + # the alignment requirements for A and B can be quite large (e.g., 256). + if m > alignment_shift_m: + m -= alignment_shift_m + if n > alignment_shift_n: + n -= alignment_shift_n + + if '_n32t32_' in kernel_name: + continue + batch_count = 1 + if mode == "functional_L0" or mode == "functional_L1" : + if index_waves == 0 and index_k == 0 : + batch_count = 3 if mode == "functional_L0" else 5 + gemm_op = "gemm" + + profiler_reference_computing_override = profiler_reference_computing + if "bstensorop" in kernel_name: + profiler_reference_computing_override = "--mode=trace" + gemm_op = "block_scaled_gemm" + + problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] + + assert m > 0 and n > 0 and k > 0 + + # Emit per-testcase metadata for perf testing usage, eventually in perf database + metadata_dict = { + "input_params": { + 'problem_size_category' : problem_size_category, + 'operation' : _getSubOperationType(operation), + 'datatype' : data_types, + 'layout' : layout3x, + 'm' : m, + 'n' : n, + 'k' : k, + 'beta' : beta, + 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta) + }, + "runtime_params": { + 'ctas_per_mma_instruction' : ctas_per_mma_instruction, + 'tilesize_m' : cta_tile_shape_m, + 'tilesize_n' : cta_tile_shape_n, + 'tilesize_k' : cta_tile_shape_k, + 'cluster_shape_m' : cluster_shape_m, + 'cluster_shape_n' : cluster_shape_n, + } + } + + cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m + cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n + cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k + + + if dynamic_datatype: + runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype) + metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a + metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b + + testcase_metadata = [ + f"cutlass_profiler --operation={gemm_op} {profiler_reference_computing_override} --error-on-no-match --error-if-nothing-is-profiled" + + f" --kernels={kernel_name}" + + f" --m={str(m)}" + + f" --n={str(n)}" + + f" --k={str(k)}" + + f" --cluster_m={str(cluster_shape_m)}" + + f" --cluster_n={str(cluster_shape_n)}" + + f" --cluster_k={str(cluster_shape_k)}" + + f" --cluster_m_fallback={str(cluster_m_fallback)}" + + f" --cluster_n_fallback={str(cluster_n_fallback)}" + + f" --cluster_k_fallback={str(cluster_k_fallback)}" + + f" --beta={str(beta)}" + + f" --batch_count={str(batch_count)}" + + f" --verification-required={str(verification_required).lower()}" + ] \ + + output_dynamic_datatype = dynamic_datatype + if output_dynamic_datatype: + testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" + + f" --runtime_input_datatype_b={runtime_datatype_b}") + + testcase_metadata.append(json.dumps(metadata_dict)) + testlist_csv_rows.append(testcase_metadata) + testcase_counter += 1 + + alpha = 1.0 + + if dynamic_datatype: + hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b) + + # If kernel_name is new, initialize its feature set with defaults + if hashed_kernel_name not in kernel_features: + kernel_features[hashed_kernel_name] = { + "is_support_dynamic_cluster": False, + "is_support_dynamic_datatype": False, + } + + # Update features for the hashed kernel name + kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster + kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype + + if hashed_kernel_name not in auditlist_csv_params_map: + auditlist_csv_params_map[hashed_kernel_name] = [] + + audit_row_params = get_kernel_params( + operation, + hashed_kernel_name, + (cluster_shape_m, cluster_shape_n, cluster_shape_k), + (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback), + (m, n, k, batch_count), + alpha, beta, + dynamic_datatype, dynamic_cluster + ) + + auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params) + + if hashed_kernel_name not in auditlist_csv_map: + audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype) + auditlist_csv_map[hashed_kernel_name] = audit_row + + with open(outfile_name, 'w') as testlist_csv: + csv_writer = csv.writer(testlist_csv, delimiter=',') + csv_writer.writerow(testlist_csv_fields) + csv_writer.writerows(testlist_csv_rows) + + with open(audit_file_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_fields) + for hashed_kernel_name, row in auditlist_csv_map.items(): + # Append the dynamic features as "Y" or "N" + dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N" + dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N" + test_count = len(auditlist_csv_params_map[hashed_kernel_name]) + csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count]) + + with open(audit_file_params_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_runtime_fields) + for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1): + for i, row in enumerate(rows): + if i == 0: + csv_writer.writerow([kernel_index, hashed_kernel_name] + row) + else: + csv_writer.writerow(["", ""] + row) + + print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.") + + # Generate a newline separated list of kernel filters + assert(len(kernel_name_set) == kernels_emitted) + output_filter_enabled = True + if output_filter_enabled: + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name, "w") as file: + kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set)) + for kernel_name in kernel_name_set: + file.write(kernel_name + "\n") + + # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and superMMA kernels in cutlass2.x generated together. + if mode == "functional_L0" or mode == "functional_L1": + # Sort the .csv file + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + with open(outfile_name) as file: + data = file.readlines() + data.sort() + with open(outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + # Sort the kernel list + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name) as file: + data = file.readlines() + data.sort() + with open(kernel_filter_outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 1e944a08..5cc4f8b4 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -1,4 +1,4 @@ -################################################################################################# + # # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause @@ -64,7 +64,7 @@ class GemmOperation: def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, - tile_scheduler = TileSchedulerType.Default + tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False , ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None @@ -74,6 +74,7 @@ class GemmOperation: GemmKind.Universal3x, GemmKind.SparseUniversal3x, GemmKind.BlockScaledUniversal3x, + GemmKind.GroupedGemmUniversal3x, } self.is_3x = gemm_kind in kinds_3x self.prefix = "3x" if self.is_3x else "" @@ -111,6 +112,12 @@ class GemmOperation: self.swizzling_functor = swizzling_functor self.tile_scheduler = tile_scheduler + # Only enable mixed input mode and mixed input shuffle for Hopper + self.mixed_input_mode = None + if self.is_mixed_input() and self.arch >= 90 and self.arch < 100: + self.mixed_input_mode = mixed_input_mode + self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle + # def is_complex(self): complex_operators = [ @@ -211,6 +218,18 @@ class GemmOperation: return extended_name + # + def mixed_input_mode_name(self): + mode_name_mapping = { + MixedInputMode.ConvertOnly: "_cvt", + MixedInputMode.ScaleOnly: "_scl", + MixedInputMode.ScaleWithZeroPoint: "_sclzr" + } + mode_name = mode_name_mapping.get(self.mixed_input_mode, "") + if self.mixed_input_shuffle: + mode_name = mode_name + "_shfl" + return mode_name + def extended_name_3x(self): '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( @@ -237,6 +256,8 @@ class GemmOperation: element_d = d_type_names, core_name = self.core_name()) + if self.mixed_input_mode != None: + extended_name = extended_name + self.mixed_input_mode_name() return extended_name def datatype_name_3x(self): @@ -768,6 +789,8 @@ using ${operation_name}_epilogue = ${epilogue_functor} >::CollectiveOp; +${mixed_dtype_prepare_code} + using ${operation_name}_mainloop = typename cutlass::gemm::collective::CollectiveBuilder< ${arch}, ${opcode_class_main}, @@ -782,7 +805,7 @@ using ${operation_name}_mainloop = // Gemm operator ${operation_name} using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, + ${problem_shape}, ${operation_name}_mainloop, ${operation_name}_epilogue, ${tile_scheduler}>; @@ -830,7 +853,18 @@ ${compile_guard_end} return SubstituteTemplate(block_scaled_template, block_scaled_values) - # + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* " + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + + return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type + def emit(self, operation): _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name()) @@ -926,17 +960,83 @@ ${compile_guard_end} element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode != None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" +using {operation_name_str}_StrideNarrow = {stride_narrow_str}; +using {operation_name_str}_ValueShuffle = {value_shuffle_str}; +static constexpr int {operation_name_str}_NumShuffleAtoms = 1; +using {operation_name_str}_MmaAtomShape = cute::Layout>>; +using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>()); +using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>" + } + narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag) + + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + values = { - 'operation_name': operation.procedural_name(), + 'operation_name': operation_name_str, 'operation_suffix': self.operation_suffix, + 'problem_shape': self.problem_shape(operation), 'element_a': element_a, - 'layout_a': LayoutTag[instance_layout_A], + 'layout_a': self.pointerize_if_grouped(operation, layout_a_str), 'element_b': element_b, - 'layout_b': LayoutTag[instance_layout_B], + 'layout_b': self.pointerize_if_grouped(operation, layout_b_str), 'element_c': DataTypeTag[operation.C.element], - 'layout_c': LayoutTag[instance_layout_C], + 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]), 'element_d': DataTypeTag[operation.D.element], - 'layout_d': LayoutTag[instance_layout_D], + 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]), 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class_main': OpcodeClassTag[opcode_class_main], 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], @@ -968,6 +1068,7 @@ ${compile_guard_end} 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), + 'mixed_dtype_prepare_code': mixed_dtype_prepare_code } return SubstituteTemplate(self.gemm_template, values) @@ -1294,7 +1395,8 @@ class EmitGemmConfigurationLibrary: GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance, GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, - GemmKind.Grouped: EmitGemmGroupedInstance + GemmKind.Grouped: EmitGemmGroupedInstance, + GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance, } self.gemm_kind_wrappers = { @@ -1306,7 +1408,8 @@ class EmitGemmConfigurationLibrary: GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation', GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', - GemmKind.Grouped: 'GemmGroupedOperation' + GemmKind.Grouped: 'GemmGroupedOperation', + GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation' } self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" @@ -1363,6 +1466,7 @@ void initialize_${configuration_name}(Manifest &manifest) { ("library_internal.h", None), ("gemm_operation.h", None), ("gemm_operation_3x.hpp", None), + ("grouped_gemm_operation_3x.hpp", None), ("sparse_gemm_operation_3x.hpp", None), ("block_scaled_gemm_operation_3x.hpp", None), ("cutlass/arch/wmma.h", None), diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index c75f3342..a4cf5f90 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -90,9 +90,11 @@ try: raise ImportError("Disabling attempt to import cutlass_library") from cutlass_library.library import * from cutlass_library.manifest import * + from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist except ImportError: from library import * from manifest import * + from emit_kernel_listing import emit_gemm_kernel_testlist ################################################################################################### # @@ -177,7 +179,8 @@ def CreateGemmUniversal3xOperator( complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, - tile_schedulers=[TileSchedulerType.Default]): + tile_schedulers=[TileSchedulerType.Default], + gemm_kind=GemmKind.Universal3x): if type(data_types) is dict: data_types = [data_types] @@ -206,7 +209,6 @@ def CreateGemmUniversal3xOperator( D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) gemm_op_extra_args = {} - gemm_kind = GemmKind.Universal3x element_compute = data_type.get("epi_type", data_type["acc_type"]) @@ -218,16 +220,43 @@ def CreateGemmUniversal3xOperator( gemm_kind = GemmKind.BlockScaledUniversal3x - operation = GemmOperation( - gemm_kind, tile_description.minimum_compute_capability, - tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule, tile_scheduler, **gemm_op_extra_args) + A_dtype = data_type["a_type"] + B_dtype = data_type["b_type"] + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) - manifest.append(operation) - operations.append(operation) + mixed_input_modes = [None] + if narrow_dtype_bits != wide_dtype_bits: + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + mixed_input_modes = [MixedInputMode.ScaleOnly] + else: + mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint] + + mixed_input_shuffle_options = [False] + if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8): + mixed_input_shuffle_options = [False, True] + + for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options): + operation = GemmOperation( + gemm_kind, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler, + mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args) + manifest.append(operation) + operations.append(operation) return operations +def is_grouped(gemm_kind): + return gemm_kind == GemmKind.GroupedGemmUniversal3x + # Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts def CreateSparseGemmUniversal3xOperator( manifest, layouts, tile_descriptions, data_types, @@ -4934,12 +4963,7 @@ def GenerateSM80(manifest, cuda_version): ################################################################################################### -def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): - if ( - not CudaToolkitVersionSatisfies(cuda_version, 12, 4) - ): - return - +def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc): layouts = [ (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) @@ -4948,49 +4972,48 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): math_instructions = [ MathInstruction( [16, 8, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, + DataType.e4m3, DataType.e4m3, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add), MathInstruction( [16, 8, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, + DataType.e4m3, DataType.e5m2, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add), MathInstruction( [16, 8, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, + DataType.e5m2, DataType.e4m3, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add), MathInstruction( [16, 8, 32], - DataType.e5m2, DataType.e5m2, DataType.f32, + DataType.e5m2, DataType.e5m2, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add), MathInstruction( [16, 8, 32], - DataType.e4m3, DataType.e4m3, DataType.f32, + DataType.e4m3, DataType.e4m3, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add_fast_accum), MathInstruction( [16, 8, 32], - DataType.e4m3, DataType.e5m2, DataType.f32, + DataType.e4m3, DataType.e5m2, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add_fast_accum), MathInstruction( [16, 8, 32], - DataType.e5m2, DataType.e4m3, DataType.f32, + DataType.e5m2, DataType.e4m3, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add_fast_accum), MathInstruction( [16, 8, 32], - DataType.e5m2, DataType.e5m2, DataType.f32, + DataType.e5m2, DataType.e5m2, element_acc, OpcodeClass.TensorOp, MathOperation.multiply_add_fast_accum), ] min_cc = 89 - max_cc = 89 - + max_cc = 100 alignment_constraints = [16,] alignment_constraints_small_channels = [16, 8, 4] @@ -5077,6 +5100,18 @@ def GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version): else: op.C.alignment = 8 +def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32) + +def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16) + # def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): @@ -5177,7 +5212,8 @@ def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): # def GenerateSM89(manifest, cuda_version): - GenerateSM89_TensorOp_16832_fp8(manifest, cuda_version) + GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version) + GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version) GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version) ################################################################################################### @@ -5189,6 +5225,7 @@ try: generate_tf32_math_instructions_sm90, generate_int8_math_instructions_sm90, generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, make_sparse_math_instructions, generate_tile_descriptions_sm90, get_valid_schedules, @@ -5201,6 +5238,7 @@ except ImportError: generate_tf32_math_instructions_sm90, generate_int8_math_instructions_sm90, generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, make_sparse_math_instructions, generate_tile_descriptions_sm90, get_valid_schedules, @@ -5208,8 +5246,8 @@ except ImportError: fix_alignments, ) -def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): +def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) @@ -5262,10 +5300,11 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): data_types=data_type, instantiation_level=instantiation_level, layout=layout, + gemm_kind=gemm_kind, ) if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) if len(stream_k_schedules): assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, @@ -5728,8 +5767,8 @@ def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): tile_schedulers=[TileSchedulerType.StreamK]) -def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): - if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): +def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): return instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) @@ -5783,10 +5822,11 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): data_types=data_type, instantiation_level=instantiation_level, layout=layout, + gemm_kind=gemm_kind, ) if len(schedules): - CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) if len(stream_k_schedules): assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, @@ -5851,6 +5891,90 @@ def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) +def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): + return + + instantiation_level = manifest.get_sm90_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + is_aligned = True + + # layouts for ABC, their alignments will be fixed later based on the data type + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + valid_types_for_a_b_acc = [ + (DataType.e4m3, DataType.f16, DataType.f32), + (DataType.e4m3, DataType.bf16, DataType.f32), + (DataType.e5m2, DataType.f16, DataType.f32), + (DataType.e5m2, DataType.bf16, DataType.f32), + (DataType.s8, DataType.f16, DataType.f32), + (DataType.s8, DataType.bf16, DataType.f32), + (DataType.u8, DataType.f16, DataType.f32), + (DataType.u8, DataType.bf16, DataType.f32), + (DataType.s4, DataType.f16, DataType.f32), + (DataType.s4, DataType.bf16, DataType.f32), + (DataType.s4, DataType.e4m3, DataType.f32), + (DataType.s4, DataType.e5m2, DataType.f32), + (DataType.u4, DataType.f16, DataType.f32), + (DataType.u4, DataType.bf16, DataType.f32), + (DataType.u2, DataType.f16, DataType.f32), + (DataType.u2, DataType.bf16, DataType.f32), + (DataType.s2, DataType.f16, DataType.f32), + (DataType.s2, DataType.bf16, DataType.f32), + ] + # Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now. + #swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc] + #valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc + + math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc) + + valid_types_for_d = [DataType.f32] + valid_types_for_c = [DataType.f32] + + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Fix alignments, DataTypeSize are in the unit of bits + alignment_bits = 128 + layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']] + layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']] + layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']] + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): @@ -6662,7 +6786,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) -def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): +def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return @@ -6680,6 +6804,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): min_cc = 100 max_cc = 100 + grouped = is_grouped(gemm_kind) + math_instructions_1sm = [ # f16 -> f16 #MathInstruction( @@ -6736,6 +6862,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): MathOperation.multiply_add)] cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1],[4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -6776,9 +6903,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): for layout in layouts: layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # for mixed precision kernels, also generate kernels that write output matrix in the A/B format # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) @@ -6806,8 +6935,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels math_instructions_2sm = [ @@ -6886,6 +7015,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): MathOperation.multiply_add)] cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -6921,13 +7051,16 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): for layout in layouts: layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] - if math_inst.instruction_shape[0] == 128: + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm else: epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # for mixed precision kernels, also generate kernels that write output matrix in the A/B format # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) @@ -6955,9 +7088,9 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version): layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) -def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): +def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return @@ -6976,6 +7109,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): min_cc = 100 max_cc = 100 epi_type = DataType.f32 + grouped = is_grouped(gemm_kind) math_instructions_1sm = [ # inst 64x128 @@ -7038,6 +7172,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): MathOperation.multiply_add)] cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -7163,9 +7298,14 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ ( data_type["d_type"] == DataType.e5m2 ): continue + # don't support runtime data type for grouped yet + if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8): + continue + kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], - tile_schedulers=tile_schedulers) + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) # 2xSM MMA kernels math_instructions_2sm = [ @@ -7241,6 +7381,7 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): ] cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -7361,15 +7502,20 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version): if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ ( data_type["d_type"] == DataType.e5m2 ): continue + # don't support runtime data type for grouped yet + if grouped and (data_type["a_type"] == DataType.f8 or data_type["b_type"] == DataType.f8): + continue - if math_inst.instruction_shape[0] == 128: + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm else: epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = KernelScheduleType.TmaWarpSpecialized2SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100 CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) - + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): @@ -7460,6 +7606,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud [2,1,1], # [1,4,1], [4,4,1] + , DynamicClusterShape ] # 1xSM MMA kernels @@ -7533,6 +7680,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud [4,1,1], # [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -7728,6 +7876,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio [2,1,1], # [1,4,1], [4,4,1] + , DynamicClusterShape ] # 1xSM MMA kernels @@ -7841,6 +7990,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio [4,1,1], # [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -8419,6 +8569,9 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) + # grouped GEMM + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) # # Block Scaled Gemm # @@ -8800,7 +8953,10 @@ def GenerateSM90(manifest, cuda_version): GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version) GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_1684(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) @@ -8899,6 +9055,12 @@ if __name__ == "__main__": if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) + if 'kernel_testlist_l0' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0") + + if 'kernel_testlist_l1' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1") + if args.selected_kernel_list is not None: if len(manifest.selected_kernels) > 0: with open(args.selected_kernel_list, 'w') as file_writer: diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index dc8a6f96..89e72f2b 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -485,19 +485,25 @@ class KernelScheduleType(enum.Enum): TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() - - + + PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto() + KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto() + KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + # KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', @@ -516,19 +522,24 @@ KernelScheduleTag = { KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', - - + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100', + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', - KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', } # @@ -549,39 +560,54 @@ KernelScheduleSuffixes = { KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', - - + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', - KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', } class EpilogueScheduleType(enum.Enum): ScheduleAuto = enum_auto() EpilogueTransposed = enum_auto() NoSmemWarpSpecialized = enum_auto() + PtrArrayNoSmemWarpSpecialized = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() TmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecialized1Sm = enum_auto() + PtrArrayTmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() # EpilogueScheduleTag = { EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong', } # @@ -589,10 +615,15 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.ScheduleAuto: '', EpilogueScheduleType.EpilogueTransposed: '', EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '_tma_1sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_tma_2sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma_cooperative', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma_pingpong', } class EpilogueFunctor3x(enum.Enum): @@ -786,6 +817,7 @@ class GemmKind(enum.Enum): PlanarComplexArray = enum_auto() Grouped = enum_auto() BlockScaledUniversal3x = enum_auto() + GroupedGemmUniversal3x = enum_auto() # GemmKindNames = { @@ -797,7 +829,8 @@ GemmKindNames = { GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped", - GemmKind.BlockScaledUniversal3x: "gemm_block_scaled" + GemmKind.BlockScaledUniversal3x: "gemm_block_scaled", + GemmKind.GroupedGemmUniversal3x: "gemm_grouped", } # @@ -838,6 +871,12 @@ EpilogueFunctorTag = { EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', } +# +class MixedInputMode(enum.Enum): + ConvertOnly = enum_auto() + ScaleOnly = enum_auto() + ScaleWithZeroPoint = enum_auto() + # class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 53285400..984ba33c 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -43,7 +43,7 @@ import os.path import shutil import sys import copy -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple, List try: import builtins @@ -153,6 +153,17 @@ def generate_int8_math_instruction_shapes_sm90(level: int): ] return filtered_list_of_wgmma_shapes +def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType): + # DataTypeSize are in the unit of bits + a_bytes = DataTypeSize[a_type] // 8 + b_bytes = DataTypeSize[b_type] // 8 + if a_bytes == 4 or b_bytes == 4: + return generate_tf32_math_instruction_shapes_sm90(wgmma_level) + elif a_bytes == 2 or b_bytes == 2: + return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level) + else: + return generate_fp8_math_instruction_shapes_sm90(wgmma_level) + ########### def generate_tf32_math_instructions_sm90(level: int): @@ -219,6 +230,22 @@ def generate_fp8_math_instructions_sm90(level: int): ] return math_instructions +def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for a_type, b_type, acc_type in types_of_a_b_acc: + math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type) + for math_instruction_shape in math_instruction_shapes: + math_instructions += [ + MathInstruction( + math_instruction_shape, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ), + ] + return math_instructions + def generate_int8_math_instructions_sm90(level: int): wgmma_level = get_wgmma_level_from_global_level(level) math_instructions = [] @@ -407,7 +434,7 @@ def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout, - instantiation_level, enable_fp8_fast_acc=True): + instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x): # Level 0: prune according to existing generator.py behavior # Level >= 1: no pruning level = get_pruning_level_from_global_level(instantiation_level) @@ -428,8 +455,6 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor - is_sparse = tile_description.math_instruction.opcode_class == OpcodeClass.SparseTensorOp - can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description) can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types) @@ -464,6 +489,16 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): return [], [] + grouped = gemm_kind == GemmKind.GroupedGemmUniversal3x + if grouped: + # the following cases are unsupported by grouped GEMM + if not is_aligned: + return [], [] + if not can_do_tma_epilogue: + return [], [] + if requires_transposed_epilogue: + return [], [] + # Early pruning if level < 1: # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64 @@ -477,20 +512,23 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if not is_void_c or d_type not in FP8_TYPES: return [], [] if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: - return [ + schedules = [] + if not grouped: + schedules.append( + [ + KernelScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + schedules.append( [ - KernelScheduleType.TmaWarpSpecializedCooperative, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ], - [ - KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, - EpilogueScheduleType.TmaWarpSpecializedCooperative - ], - ] , [] + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum if not grouped else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecializedCooperative if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + ]) + return schedules, [] return [], [] if is_fp8 and not is_large_fp8_tile: - valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16] + valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void] # Prune all configs with fp8 source, and all configs with non-fp8 output # that have different dtypes for source and output. if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type): @@ -504,6 +542,33 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if is_void_c and not can_do_tma_epilogue: return [], [] + # For mixed input data types + a_type_size = DataTypeSize[data_types["a_type"]] + b_type_size = DataTypeSize[data_types["b_type"]] + if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [] + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + schedules.append([ + KernelScheduleType.TmaWarpSpecialized, + epilogue_schedule + ]) + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule + ]) + if cta_m >= 128: + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + else: + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative + schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + return schedules, [] + if not is_aligned: schedules = [[KernelScheduleType.CpAsyncWarpSpecialized, default_epilogue]] @@ -521,6 +586,15 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, return schedules, stream_k_schedules + if grouped: + pingpong = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum + cooperative = KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative if not is_fp8 else KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum + if can_do_tma_epilogue: + schedules.append([pingpong, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong]) + if can_do_cooperative: + schedules.append([cooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative]) + return schedules, [] + schedules = [] # Pruning: emit Void-C kernels with persistent kernels only if level >= 1 or not is_void_c: diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index 2f82812f..7e7b0498 100644 --- a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -198,7 +198,6 @@ set(header_files_to_check cutlass/version.h cutlass/wmma_array.h cutlass/workspace.h - cutlass/exmy_base.h cutlass/float_subbyte.h diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 0abda31d..4d10b025 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -125,6 +125,50 @@ function(cutlass_test_unit_add_executable NAME) endfunction() + +function(cutlass_test_unit_add_executable_split_file NAME) + # Given the input arguments to cutlass_test_unit_add_executable, creates + # a new set of arguments in which each file has at most one TEST definition, + # and calls cutlass_test_unit_add_executable with the newly-formed arguments. + # The goal of this is to reduce the memory consumed while building CUTLASS + # tests with a high degree of parallelism while not requiring developers + # to split unit tests across multiple files artificially. + + # Get all arguments other than the NAME of the target + list(SUBLIST ARGV 1 ${ARGC} SUBARGV) + + if (CUTLASS_UNIT_TEST_SPLIT_FILES) + execute_process( + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMAND ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/tools/scripts/split_test_cmake.py + ${NAME} + ${CMAKE_CURRENT_SOURCE_DIR} + --src_files ${SUBARGV} + --dst_dir ${CMAKE_CURRENT_BINARY_DIR} + RESULT_VARIABLE cutlass_test_SPLIT_RESULT + OUTPUT_VARIABLE cutlass_test_SPLIT_OUTPUT + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/test_split_files.txt + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/test_split_error.log + ) + + if(NOT cutlass_test_SPLIT_RESULT EQUAL 0) + message(FATAL_ERROR "Error splitting unit test. See ${CMAKE_CURRENT_BINARY_DIR}/test_split_error.log") + endif() + + # Forward the values printed by split_test_cmake.py as arguments to cutlass_test_unit_add_executable. + # We additionally specify to add -I${CMAKE_CURRENT_SOURCE_DIR} to the target. This is necessary because + # the splitting process writes new files to ${CMAKE_CURRENT_BINARY_DIR}, but many CUTLASS unit tests + # use relative imports for including testbeds (e.g., '#include "../testbed.hpp"'). These headers are + # not written to ${CMAKE_CURRENT_BINARY_DIR} during the splitting process, so we must indicate that + # headers can also be searched for from ${CMAKE_CURRENT_SOURCE_DIR}. + file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/test_split_files.txt NEW_OPTIONS) + cutlass_test_unit_add_executable(${NAME} ${NEW_OPTIONS} EXTRA_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}) + else() + # Simply pass arguments through + cutlass_test_unit_add_executable(${ARGV}) + endif() +endfunction() + add_custom_target(cutlass_test_unit) add_custom_target(test_unit) diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index 8ea7dde8..cf489d13 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -276,11 +276,12 @@ endif() if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 89) - # Conv - F8 input, F8 output, F32 accumulation + # Conv - F8 input, F8 output cutlass_test_unit_add_executable( cutlass_test_unit_conv_device_tensorop_f8_sm89 conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu + conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f16_sm89.cu ) endif() diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f16_sm89.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f16_sm89.cu new file mode 100644 index 00000000..108e48a5 --- /dev/null +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f16_sm89.cu @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide Conv2d fprop interface with: + A: NHWC, of type FE4M4 or FE5M2 + B: NHWC, of type FE4M3 or FE5M2 + C: NHWC, of FE4M3 or FE5M2 + Accum: F16 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h" +#include "cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "conv2d_with_absmax_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Conv2d_Fprop_Analytic_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc_tensor_op_f16, + identity_128x256x64_64x3_64x64x64) { + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithAbsMax< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementOutput, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + kStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + bool passed = test::conv::device::TestAllConv2dWithAbsmax(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc_tensor_op_f16, + relu_128x256x64_64x3_64x64x64) { + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::ReLu, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithAbsMax< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementOutput, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + kStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + bool passed = test::conv::device::TestAllConv2dWithAbsmax(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc_tensor_op_f16, + identity_fastacc_128x256x64_64x3_64x64x64) { + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithAbsMax< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementOutput, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + kStages, + cutlass::arch::OpMultiplyAddFastAccum, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + bool passed = test::conv::device::TestAllConv2dWithAbsmax(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc_tensor_op_f16, + identity_noScale_128x256x64_64x3_64x64x64) { + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFpropWithAbsMax< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementOutput, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + kStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized + >::Kernel; + + using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; + + bool passed = test::conv::device::TestAllConv2dWithAbsmax( + /* scaleA = */false, + /* scaleB = */false, + /* scaleC = */false + ); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu index fe82e9ec..7f12cad3 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f8nhwc_f8nhwc_f8nhwc_tensor_op_f32_sm89.cu @@ -49,7 +49,7 @@ #include "conv2d_with_absmax_testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// @@ -365,4 +365,4 @@ TEST(SM89_Device_Conv2d_Fprop_Optimized_ImplicitGemm_fe4m3nhwc_fe4mnhwc_fe4mnhwc //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM89_SUPPORTED +#endif // CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 5fdda499..35d48212 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -30,8 +30,6 @@ add_custom_target(cutlass_test_unit_gemm_device) add_custom_target(test_unit_gemm_device) -add_subdirectory(sm100_blockscaled_tensorop_gemm) - ################################################################################ @@ -53,6 +51,12 @@ endfunction() ################################################################################ + +add_subdirectory(sm100_blockscaled_tensorop_gemm) +add_subdirectory(sm100_tensorop_gemm) + + + cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_simt @@ -548,8 +552,10 @@ cutlass_test_unit_gemm_device_add_executable( gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu + gemm_f8t_f8n_f16t_tensor_op_f16_sm89.cu gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu # gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu + gemm_f8t_f8n_f8t_tensor_op_f16_sm89.cu ) cutlass_test_unit_gemm_device_add_executable( @@ -829,56 +835,6 @@ endif() if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_sm100_fp16_gemm - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm100_gemm_f16_f16_f32_tensor_op_f32.cu -) - -cutlass_test_unit_gemm_device_add_executable( - cutlass_test_unit_gemm_device_tensorop_sm100_stream_k - - sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu -) - -cutlass_test_unit_gemm_device_add_executable( - cutlass_test_unit_gemm_device_sm100_bf16_gemm - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu -) - - -cutlass_test_unit_gemm_device_add_executable( - cutlass_test_unit_gemm_device_tensorop_stride_batch_alpha_beta_sm100 - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu -) - -cutlass_test_unit_gemm_device_add_executable( - cutlass_test_unit_gemm_device_tensorop_runtime_datatype_sm100 - - # No batching of source to control compiler memory usage - BATCH_SOURCES ON - BATCH_SIZE 1 - - sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu - sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu - sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu - sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu -) - cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_16b_tensorop_sm100_ptr_array diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f16t_tensor_op_f16_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f16t_tensor_op_f16_sm89.cu new file mode 100644 index 00000000..2b449a02 --- /dev/null +++ b/test/unit/gemm/device/gemm_f8t_f8n_f16t_tensor_op_f16_sm89.cu @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface with: + A: row major, of type FE4M4 or FE5M2 + B: column major, of type FE4M3 or FE5M2 + C: row major, of type F16 + Accum: F16 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe5m2n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e5m2_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe5m2t_fe4m3n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementA = cutlass::float_e5m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe5m2t_fe5m2n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementA = cutlass::float_e5m2_t; + using ElementB = cutlass::float_e5m2_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu index f82f7291..9e00a3f0 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sm89.cu @@ -51,7 +51,7 @@ #include "testbed.h" -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// @@ -151,4 +151,4 @@ TEST(SM89_Device_Gemm_fe5m2t_fe5m2n_f32t_tensor_op_f32, 128x256x64_64x64x64) { //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM89_SUPPORTED +#endif // CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu index d0b0d068..c52ffddd 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f32t_tensor_op_f32_sparse_sm89.cu @@ -51,7 +51,7 @@ #include "testbed_sparse.h" -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// @@ -151,4 +151,4 @@ TEST(SM89_Device_Sparse_Gemm_fe5m2t_fe5m2n_f32t_tensor_op_f32, 128x128x128_64x64 //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM89_SUPPORTED +#endif // CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f16_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f16_sm89.cu new file mode 100644 index 00000000..4b9146ab --- /dev/null +++ b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f16_sm89.cu @@ -0,0 +1,430 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface with: + A: row major, of type FE4M4 or FE5M2 + B: column major, of type FE4M3 or FE5M2 + C: row major, of FE4M3 or FE5M2 + Accum: F16 +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h" +#include "cutlass/gemm/device/gemm_universal_with_absmax.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" +#include "testbed_with_absmax.h" + +#if defined(CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, identity_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, identity_fastacc_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + static int const kAlignment = 16; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages, + kAlignment, kAlignment, cutlass::arch::OpMultiplyAddFastAccum + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, relu_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::ReLu, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::ReLu>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe5m2n_fe4m3t_tensor_op_f16, identity_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e5m2_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe5m2t_fe4m3n_fe4m3t_tensor_op_f16, identity_128x256x64_64x64x64) { + using ElementA = cutlass::float_e5m2_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe5m2t_fe5m2n_fe4m3t_tensor_op_f16, identity_128x256x64_64x64x64) { + using ElementA = cutlass::float_e5m2_t; + using ElementB = cutlass::float_e5m2_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe5m2t_tensor_op_f16, identity_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e5m2_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe5m2t_fe5m2n_fe5m2t_tensor_op_f16, identity_diff_aux_output_types_128x256x64_64x64x64) { + using ElementA = cutlass::float_e5m2_t; + using ElementB = cutlass::float_e5m2_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = cutlass::float_e5m2_t; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, identity_128x128x64_32x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, identity_noScale_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = ElementOutput; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>( + /* scaleA = */false, + /* scaleB = */false, + /* scaleC = */false + ); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f16, identity_noAux_128x256x64_64x64x64) { + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::float_e4m3_t; + using ElementAuxOutput = float; + using ElementAccumulator = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + static int const kStages = 3; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::Identity, + ElementOutput, + ElementAuxOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >; + + using Gemm = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 256, 64>, cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages + >; + + bool passed = test::gemm::device::TestAllGemmWithAbsmax, cutlass::epilogue::thread::Identity>(); + EXPECT_TRUE(passed); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_F16_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu index 1d0bfd95..5558d680 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sm89.cu @@ -54,7 +54,7 @@ #include "testbed.h" #include "testbed_with_absmax.h" -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// @@ -427,4 +427,4 @@ TEST(SM89_Device_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_noAux_128x256 //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM89_SUPPORTED +#endif // CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu index 8007ece8..79325d00 100644 --- a/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu +++ b/test/unit/gemm/device/gemm_f8t_f8n_f8t_tensor_op_f32_sparse_sm89.cu @@ -54,7 +54,7 @@ #include "testbed_sparse.h" #include "testbed_with_absmax.h" -#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) +#if defined(CUTLASS_ARCH_SPARSE_MMA_F32_SM89_SUPPORTED) //////////////////////////////////////////////////////////////////////////////// @@ -461,4 +461,4 @@ TEST(SM89_Device_Sparse_Gemm_fe4m3t_fe4m3n_fe4m3t_tensor_op_f32, identity_noAux_ //////////////////////////////////////////////////////////////////////////////// -#endif // CUTLASS_ARCH_MMA_SM89_SUPPORTED +#endif // CUTLASS_ARCH_MMA_F32_SM89_SUPPORTED diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index bf1d11fe..ea69b394 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -893,7 +893,6 @@ struct HostCollectiveMainloop>; static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; + static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; @@ -1462,7 +1465,7 @@ struct HostCollectiveEpilogue { CheckEquality check_relative_equality = CheckEquality::EXACT; // Are scalars copied to device memory before kernel launch ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; - // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector VectorScale vector_scale_mode = VectorScale::DISABLED; // Random distribution with which to initialize the A/B/C/D/Aux scaling factors @@ -1555,8 +1558,7 @@ struct HostCollectiveEpilogue { auto col_vector_coord = cutlass::make_Coord(M); auto row_vector_coord = cutlass::make_Coord(N); auto batch_vector_coord = cutlass::make_Coord(L); - auto ML_coord = cutlass::make_Coord(M * L); - if constexpr (IsPerRowScaleEnabled) { + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { // scalars if (vector_scale_mode == VectorScale::DISABLED) { // batched scalars @@ -1581,8 +1583,9 @@ struct HostCollectiveEpilogue { } // batched vectors else { - alpha.resize(ML_coord, true); - beta.resize(ML_coord, true); + auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); + alpha.resize(batched_vector_coord, true); + beta.resize(batched_vector_coord, true); EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); if (beta_ != ElementScalar(0)) { EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); @@ -1627,9 +1630,7 @@ struct HostCollectiveEpilogue { scale_D.sync_device(); } - if constexpr ( - IsRowBiasEnabled - ) { + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); bias.sync_device(); @@ -1810,7 +1811,6 @@ struct HostCollectiveEpilogue { } passed &= passed_sf; } - return passed; } @@ -1823,7 +1823,7 @@ struct HostCollectiveEpilogue { << ", scale_b: " << scale_B.at(coord_0) << ", scale_c: " << scale_C.at(coord_0); } - if constexpr (IsPerRowScaleEnabled) { + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { file << "\n\nvalpha = \n" << alpha.host_view(); file << "\n\nvbeta = \n" << beta.host_view(); } else { @@ -1853,9 +1853,10 @@ struct HostCollectiveEpilogue { file << "\n\n"; } - if constexpr (IsRowBiasEnabled) { + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { file << "\n\nBias = \n" << bias.host_view(); } + if constexpr (IsAuxInEnabled) { file << "\n\nAux Input = \n" << tensor_Aux.host_view(); } @@ -1876,7 +1877,6 @@ struct HostCollectiveEpilogue { << "\n\nSFD Reference =\n" << reference_SFD.host_view() << "\n\nSFD Computed =\n" << tensor_SFD.host_view(); } - file << "\nC =\n" << tensor_C.host_view() @@ -1921,6 +1921,12 @@ struct HostCollectiveEpilogue { fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); } + else if constexpr (IsPerColScaleEnabled) { + int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + } else { if constexpr (not IsFfma2Kernel) { if (use_device_scalars == ScalarLoc::ON_DEVICE) { @@ -1943,9 +1949,7 @@ struct HostCollectiveEpilogue { fusion_args.scale_d_ptr = scale_D.device_data(); } - if constexpr ( - IsRowBiasEnabled - ) { + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { fusion_args.bias_ptr = bias.device_data(); } @@ -1993,7 +1997,6 @@ struct HostCollectiveEpilogue { arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); arguments.thread.norm_constant_ptr = norm_constant.device_data(); } - } return arguments; @@ -2025,6 +2028,12 @@ struct HostCollectiveEpilogue { return cute::make_tensor(detail::make_iterator(alpha.host_data()), cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } else { return cute::make_tensor(detail::make_iterator(alpha.host_data()), cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); @@ -2038,6 +2047,12 @@ struct HostCollectiveEpilogue { return cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } else { return cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); @@ -2069,8 +2084,8 @@ struct HostCollectiveEpilogue { ActivationFunctor, decltype(SfD), Int, - cutlass::plus - , false /*PerColumnBias_*/ + cutlass::plus, + IsColBiasEnabled , SfGenStrategy > epilogue_params{}; @@ -2086,8 +2101,7 @@ struct HostCollectiveEpilogue { epilogue_params.scale_d = scale_D.at(coord_0); } - if constexpr (IsRowBiasEnabled - or IsDeBiasEnabled) + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) { epilogue_params.Bias = Bias; } @@ -2110,7 +2124,7 @@ struct HostCollectiveEpilogue { } } - if constexpr (IsPerRowScaleEnabled) { + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { epilogue_params.Valpha = Valpha; if (vector_scale_mode == VectorScale::ENABLED) { epilogue_params.Vbeta = Vbeta; @@ -3294,9 +3308,14 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; - std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); + int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); + if constexpr (std::is_base_of_v) { + max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); + max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); + } + std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; + std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; if constexpr (cute::is_same_v) { @@ -3307,7 +3326,8 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v(typename Gemm::GemmKernel::TileShape{}); - std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; std::vector decomposition_modes = {DecompositionMode::Heuristic}; @@ -3323,7 +3343,7 @@ bool TestAll(double alpha = 1.0, double beta = cute::is_same_v(alpha), diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index 1ae073c4..6d74d99b 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -1119,7 +1119,6 @@ struct HostCollectiveEpilogue { std::vector> tensors_SFD; std::vector> references_SFD; cutlass::DeviceAllocation device_tensors_SFD; - using ElementCompute = typename FusionOp::ElementCompute; using ElementScalar = typename FusionOp::ElementScalar; @@ -2205,8 +2204,7 @@ bool TestSmall(double alpha = 1.0, double beta = 1.0, static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); alignment_bits = cutlass::detail::get_input_alignment_bits(); - // For fp4 and fp6 mx kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. - + // For fp4 and fp6 QMMA kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. int alignment_input = (alignment_bits / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits / cute::sizeof_bits::value); diff --git a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt index 01e79c98..a7656fca 100644 --- a/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt +++ b/test/unit/gemm/device/sm100_blockscaled_tensorop_gemm/CMakeLists.txt @@ -46,7 +46,7 @@ add_custom_target( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf6 ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_bstensorop_sm100_nvf4xnvf4 BATCH_SOURCES ON @@ -57,7 +57,7 @@ cutlass_test_unit_add_executable( nvf4_nvf4_f16_nvfp4_epilogue.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf4 BATCH_SOURCES ON @@ -67,7 +67,7 @@ cutlass_test_unit_add_executable( mxf4_mxf4_void_f16_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf6 BATCH_SOURCES ON @@ -77,7 +77,7 @@ cutlass_test_unit_add_executable( mxf6_mxf6_void_bf16_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf8 BATCH_SOURCES ON @@ -87,7 +87,7 @@ cutlass_test_unit_add_executable( mxf8_mxf8_void_f8_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf8 BATCH_SOURCES ON @@ -97,7 +97,7 @@ cutlass_test_unit_add_executable( mxf6_mxf8_void_f32_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf6 BATCH_SOURCES ON @@ -107,7 +107,7 @@ cutlass_test_unit_add_executable( mxf8_mxf6_f16_f8_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf8 BATCH_SOURCES ON @@ -117,7 +117,7 @@ cutlass_test_unit_add_executable( mxf4_mxf8_bf16_bf16_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf8xmxf4 BATCH_SOURCES ON @@ -127,7 +127,7 @@ cutlass_test_unit_add_executable( mxf8_mxf4_f16_bf16_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf6xmxf4 BATCH_SOURCES ON @@ -137,7 +137,7 @@ cutlass_test_unit_add_executable( mxf6_mxf4_f16_f16_nt_layout.cu ) -cutlass_test_unit_add_executable( +cutlass_test_unit_gemm_device_add_executable_split_file( cutlass_test_unit_gemm_device_bstensorop_sm100_mxf4xmxf6 BATCH_SOURCES ON diff --git a/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu deleted file mode 100644 index 5cd4158e..00000000 --- a/test/unit/gemm/device/sm100_gemm_bf16_bf16_f32_tensor_op_f32.cu +++ /dev/null @@ -1,323 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -/// A Row B Col -TEST(SM100_Device_Gemm_f16t_f16n_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { - using ElementA = cutlass::bfloat16_t; - using ElementB = cutlass::bfloat16_t; - using ElementC = void; - using ElementD = float; - using ElementCompute = float; - using ElementAccumulator = float; - using GmemLayoutA = cutlass::layout::RowMajor; - using GmemLayoutB = cutlass::layout::ColumnMajor; - using GmemLayoutC = cutlass::layout::RowMajor; - using ClusterTileShape_MNK = Shape<_512,_512,_128>; - using ClusterShape_MNK = Shape<_4,_4,_1>; - using MmaTileShape_MNK = Shape<_256,_128,_128>; - using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); - - // - // Construct CollectiveEpilogue - // - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, GmemLayoutC, 16, - ElementD, GmemLayoutC, 16, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - // - // Construct CollectiveMainloop - // - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, GmemLayoutA, 8, - ElementB, GmemLayoutB, 8, - ElementAccumulator, - MmaTileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 0); - EXPECT_TRUE(pass); -} - -/// A Col B Row -TEST(SM100_Device_Gemm_f16n_f16t_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { - using ElementA = cutlass::bfloat16_t; - using ElementB = cutlass::bfloat16_t; - using ElementC = void; - using ElementD = float; - using ElementCompute = float; - using ElementAccumulator = float; - using GmemLayoutA = cutlass::layout::ColumnMajor; - using GmemLayoutB = cutlass::layout::RowMajor; - using GmemLayoutC = cutlass::layout::RowMajor; - using ClusterTileShape_MNK = Shape<_512,_512,_128>; - using ClusterShape_MNK = Shape<_4,_4,_1>; - using MmaTileShape_MNK = Shape<_256,_128,_128>; - using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); - - // - // Construct CollectiveEpilogue - // - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, GmemLayoutC, 16, - ElementD, GmemLayoutC, 16, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - // - // Construct CollectiveMainloop - // - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, GmemLayoutA, 8, - ElementB, GmemLayoutB, 8, - ElementAccumulator, - MmaTileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 0); - EXPECT_TRUE(pass); -} - -/// A Row B Row -TEST(SM100_Device_Gemm_f16t_f16t_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { - using ElementA = cutlass::bfloat16_t; - using ElementB = cutlass::bfloat16_t; - using ElementC = void; - using ElementD = float; - using ElementCompute = float; - using ElementAccumulator = float; - using GmemLayoutA = cutlass::layout::RowMajor; - using GmemLayoutB = cutlass::layout::RowMajor; - using GmemLayoutC = cutlass::layout::RowMajor; - using ClusterTileShape_MNK = Shape<_512,_512,_128>; - using ClusterShape_MNK = Shape<_4,_4,_1>; - using MmaTileShape_MNK = Shape<_256,_128,_128>; - using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); - - // - // Construct CollectiveEpilogue - // - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, GmemLayoutC, 16, - ElementD, GmemLayoutC, 16, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - // - // Construct CollectiveMainloop - // - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, GmemLayoutA, 8, - ElementB, GmemLayoutB, 8, - ElementAccumulator, - MmaTileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 0); - EXPECT_TRUE(pass); -} - -/// A Col B Col -TEST(SM100_Device_Gemm_f16n_f16n_f32t_tensorop_2sm_f32, 512x512x128_4x4x1) { - using ElementA = cutlass::bfloat16_t; - using ElementB = cutlass::bfloat16_t; - using ElementC = void; - using ElementD = float; - using ElementCompute = float; - using ElementAccumulator = float; - using GmemLayoutA = cutlass::layout::ColumnMajor; - using GmemLayoutB = cutlass::layout::ColumnMajor; - using GmemLayoutC = cutlass::layout::RowMajor; - using ClusterTileShape_MNK = Shape<_512,_512,_128>; - using ClusterShape_MNK = Shape<_4,_4,_1>; - using MmaTileShape_MNK = Shape<_256,_128,_128>; - using OutputCtaShape = decltype(shape_div(ClusterTileShape_MNK{}, ClusterShape_MNK{})); - - // - // Construct CollectiveEpilogue - // - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, GmemLayoutC, 16, - ElementD, GmemLayoutC, 16, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - // - // Construct CollectiveMainloop - // - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, GmemLayoutA, 8, - ElementB, GmemLayoutB, 8, - ElementAccumulator, - MmaTileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 0); - EXPECT_TRUE(pass); -} - -TEST(SM100_Device_Gemm_bf16t_bf16t_bf32_void_f32n_tensor_op, 128x256x64_1x2x1) { - using ElementA = cutlass::bfloat16_t; - using LayoutA = cutlass::layout::RowMajor; - using ElementB = cutlass::bfloat16_t; - using LayoutB = cutlass::layout::RowMajor; - using ElementAccumulator = float; - using LayoutC = cutlass::layout::ColumnMajor; - using MmaTileShape = Shape<_128,_128,_64>; - using TileShape_MNK = Shape<_128,_256,_64>; - using ClusterShape_MNK = Shape<_1,_2,_1>; - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - void, LayoutC, 8, - float, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmall(1.0, 0.0); - EXPECT_TRUE(pass); -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu deleted file mode 100644 index 7547b757..00000000 --- a/test/unit/gemm/device/sm100_gemm_f16_f16_f16_tensor_op_f32_stream_k.cu +++ /dev/null @@ -1,250 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Tests for device-wide GEMM interface with stream-K scheduling -*/ - - - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" - -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -using namespace cute; - -TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 128x256x64_1x2x1) { - using ElementA = cutlass::half_t; - using LayoutA = cutlass::layout::RowMajor; - using ElementB = cutlass::half_t; - using LayoutB = cutlass::layout::RowMajor; - using ElementAccumulator = float; - using LayoutC = cutlass::layout::ColumnMajor; - using TileShape_MNK = Shape<_128,_256,_64>; - using ClusterShape_MNK = Shape<_1,_2,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using Testbed = Testbed3x; - bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); - EXPECT_TRUE(result); -} - -TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 256x128x64_2x1x1) { - using ElementA = cutlass::half_t; - using LayoutA = cutlass::layout::RowMajor; - using ElementB = cutlass::half_t; - using LayoutB = cutlass::layout::RowMajor; - using ElementAccumulator = float; - using LayoutC = cutlass::layout::ColumnMajor; - using TileShape_MNK = Shape<_256,_128,_64>; - using ClusterShape_MNK = Shape<_2,_1,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using Testbed = Testbed3x; - bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); - EXPECT_TRUE(result); -} - -TEST(SM100_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_stream_k, 256x256x64_2x2x1) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using TileShape_MNK = Shape<_256,_256,_64>; - using ClusterShape_MNK = Shape<_2,_2,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using Testbed = Testbed3x; - bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); - EXPECT_TRUE(result); -} - -/////////////////////////////////////////////////////////////////////////////// - -TEST(SM100_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_stream_k, 256x128x64_2x4x1) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using TileShape_MNK = Shape<_256,_256,_64>; - using ClusterShape_MNK = Shape<_2,_4,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue, - cutlass::gemm::StreamKScheduler - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using Testbed = Testbed3x; - bool result = TestSmall(1.0, 0.0, CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {64, 1024, 2048}); - EXPECT_TRUE(result); -} - -#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu deleted file mode 100644 index ea7389a7..00000000 --- a/test/unit/gemm/device/sm100_gemm_f16_f16_f32_tensor_op_f32.cu +++ /dev/null @@ -1,104 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -TEST(SM100_Device_Gemm_f16t_f16t_f32_void_f16n_tensor_op, 128x256x64_1x2x1) { - using ElementA = cutlass::half_t; - using LayoutA = cutlass::layout::RowMajor; - using ElementB = cutlass::half_t; - using LayoutB = cutlass::layout::RowMajor; - using ElementAccumulator = float; - using LayoutC = cutlass::layout::ColumnMajor; - using TileShape_MNK = Shape<_128,_256,_64>; - using ClusterShape_MNK = Shape<_1,_2,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape_MNK{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(TileShape_MNK{}, ClusterShape_MNK{})); - using MmaTileShape = decltype(shape_div(TileShape_MNK{}, AtomThrShape{})); - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - void, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - MmaTileShape, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmall(1.0, 0.0); - EXPECT_TRUE(pass); -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu deleted file mode 100644 index aaf6d622..00000000 --- a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_runtime_datatype.cu +++ /dev/null @@ -1,156 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_runtime_datatype, 512x512x128_4x4x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, cutlass::layout::RowMajor, 4, - float, cutlass::layout::RowMajor, 4, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - float, - float, - float, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float4_t, cutlass::layout::RowMajor, 128, - cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E2M1, cute::UMMA::MXF8F6F4Format::E2M1); - EXPECT_TRUE(pass); - -} - - -TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, cutlass::layout::RowMajor, 4, - float, cutlass::layout::RowMajor, 4, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - float, - float, - float, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float4_t, cutlass::layout::RowMajor, 128, - cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E2M1, cute::UMMA::MXF8F6F4Format::E2M1); - EXPECT_TRUE(pass); -} - -#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu deleted file mode 100644 index a2f0971f..00000000 --- a/test/unit/gemm/device/sm100_gemm_f6_f6_f32_tensor_op_f32_runtime_datatype.cu +++ /dev/null @@ -1,156 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -TEST(SM100_Device_Gemm_e3m2t_e2m3n_f32t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, cutlass::layout::RowMajor, 4, - float, cutlass::layout::RowMajor, 4, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - float, - float, - float, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float6_t, cutlass::layout::RowMajor, 128, - cutlass::type_erased_dynamic_float6_t, cutlass::layout::ColumnMajor, 128, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E3M2, cute::UMMA::MXF8F6F4Format::E2M3); - EXPECT_TRUE(pass); - -} - -TEST(SM100_Device_Gemm_e3m2t_e2m3n_f32t_tensorop_1sm_f32_runtime_datatype, 512x512x128_4x4x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, cutlass::layout::RowMajor, 4, - float, cutlass::layout::RowMajor, 4, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - float, - float, - float, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float6_t, cutlass::layout::RowMajor, 128, - cutlass::type_erased_dynamic_float6_t, cutlass::layout::ColumnMajor, 128, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E3M2, cute::UMMA::MXF8F6F4Format::E2M3); - EXPECT_TRUE(pass); - -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu deleted file mode 100644 index bdd342fe..00000000 --- a/test/unit/gemm/device/sm100_gemm_f8_f4_f32_tensor_op_f32_runtime_datatype.cu +++ /dev/null @@ -1,109 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -TEST(SM100_Device_Gemm_e4m3t_e2m1n_f32t_tensorop_2sm_f32_runtime_datatype, 256x128x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, cutlass::layout::RowMajor, 4, - float, cutlass::layout::RowMajor, 4, - cutlass::epilogue::TmaWarpSpecialized2Sm, - - cutlass::epilogue::fusion::LinearCombination< - float, - float, - float, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float4_t, cutlass::layout::ColumnMajor, 128, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E2M1); - EXPECT_TRUE(pass); - -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu deleted file mode 100644 index 74791e83..00000000 --- a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype.cu +++ /dev/null @@ -1,297 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" - -#include "cutlass/epilogue/thread/activation.h" -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_2sm_f32_runtime_datatype, 256x128x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::epilogue::TmaWarpSpecialized2Sm, - - cutlass::epilogue::fusion::LinearCombination< - cutlass::float_e4m3_t, - float, - cutlass::float_e4m3_t, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); - EXPECT_TRUE(pass); - -} - -TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - cutlass::float_e4m3_t, - float, - cutlass::float_e4m3_t, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); - EXPECT_TRUE(pass); - -} - -TEST(SM100_Device_Gemm_e4m3t_e5m2n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - cutlass::float_e4m3_t, - float, - cutlass::float_e4m3_t, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E5M2); - EXPECT_TRUE(pass); - -} - -TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::float_e4m3_t, cutlass::layout::RowMajor, 16, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - cutlass::float_e4m3_t, - float, - cutlass::float_e4m3_t, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E4M3, cute::UMMA::MXF8F6F4Format::E4M3); - EXPECT_TRUE(pass); - -} - -TEST(SM100_Device_Gemm_e5m2t_e5m2n_e5m2t_tensorop_2sm_f32_runtime_datatype, 256x256x128_2x2x1) { - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cute::Shape, - cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::float_e5m2_t, cutlass::layout::RowMajor, 16, - cutlass::float_e5m2_t, cutlass::layout::RowMajor, 16, - cutlass::epilogue::TmaWarpSpecialized1Sm, - - cutlass::epilogue::fusion::LinearCombination< - cutlass::float_e5m2_t, - float, - cutlass::float_e5m2_t, - float - > - - >::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 16, - cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 16, - float, - cute::Shape, - cute::Shape, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - cute::Shape, - CollectiveMainloop, - CollectiveEpilogue, - void>; - - using namespace test::gemm::device; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E5M2); - EXPECT_TRUE(pass); - -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu deleted file mode 100644 index 187d820c..00000000 --- a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_s32_batch_alpha_beta.cu +++ /dev/null @@ -1,230 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - - - -/*! \file - \brief Tests for device-wide GEMM interface -*/ - -#include - -#include "cutlass/cutlass.h" -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" - -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" - -#include "../../common/cutlass_unit_test.h" - -#include "gemm_testbed_3x.hpp" - -using namespace cute; - -#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) - -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -////////////////////////////////////////// Test Batch alpha and beta ////////////////////////////////////////// -/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1cta_s32_batch_alpha_beta, 128x64x128_1x1x1) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using ElementA = cutlass::float_e4m3_t; - using ElementB = cutlass::float_e4m3_t; - using ElementC = cutlass::float_e4m3_t; - using ElementD = cutlass::float_e4m3_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementBias = cutlass::half_t; - using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; - using ClusterShape = Shape<_1,_1,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); - using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); - - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - - using FusionOperation = cutlass::epilogue::fusion::LinearCombination< - ElementD, - ElementCompute, - ElementC, - ElementBias - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, 16 / sizeof(ElementC), - ElementD, LayoutC, 16 / sizeof(ElementD), - EpilogueSchedule, - FusionOperation - >::CollectiveOp; - - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, LayoutA, 16 / sizeof(ElementA), - ElementB, LayoutB, 16 / sizeof(ElementB), - ElementAccumulator, - MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainloopSchedule - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 1.0); // beta is [1.0, 2.0] - EXPECT_TRUE(pass); -} - -TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_bias_relu_batch_alpha_beta, 128x128x128_1x1x1) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using ElementA = cutlass::float_e4m3_t; - using ElementB = cutlass::float_e4m3_t; - using ElementC = cutlass::float_e4m3_t; - using ElementD = cutlass::float_e4m3_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementBias = cutlass::half_t; - using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; - using ClusterShape = Shape<_1,_1,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); - using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); - - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< - cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementBias>; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, 16 / sizeof(ElementC), - ElementD, LayoutC, 16 / sizeof(ElementD), - EpilogueSchedule, - FusionOperation - >::CollectiveOp; - - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, LayoutA, 16 / sizeof(ElementA), - ElementB, LayoutB, 16 / sizeof(ElementB), - ElementAccumulator, - MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainloopSchedule - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, 0.5); // beta is [0.5, 1.5] - EXPECT_TRUE(pass); -} - -TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3n_tensorop_1sm_f32_bias_relu__batch_alpha_beta0, 128x128x128_1x1x1) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - using ElementA = cutlass::float_e4m3_t; - using ElementB = cutlass::float_e4m3_t; - using ElementC = cutlass::float_e4m3_t; - using ElementD = cutlass::float_e4m3_t; - using ElementAccumulator = float; - using ElementCompute = float; - using ElementBias = cutlass::half_t; - using ClusterTileShape = cute::Shape<_128,_128,Int<128 / sizeof(ElementA)>>; - using ClusterShape = Shape<_1,_1,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); - using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); - - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; - using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< - cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementBias>; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, - ElementC, LayoutC, 16 / sizeof(ElementC), - ElementD, LayoutC, 16 / sizeof(ElementD), - EpilogueSchedule, - FusionOperation - >::CollectiveOp; - - using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, LayoutA, 16 / sizeof(ElementA), - ElementB, LayoutB, 16 / sizeof(ElementB), - ElementAccumulator, - MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainloopSchedule - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - auto pass = test::gemm::device::TestSmallFusion(1.0, -1.0); // beta is [-1.0, 0.0] - EXPECT_TRUE(pass); -} - -#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt new file mode 100644 index 00000000..daed7ed7 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt @@ -0,0 +1,71 @@ +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# + +# + +add_custom_target( + cutlass_test_unit_gemm_device_sm100_tensorop + DEPENDS + cutlass_test_unit_gemm_device_tensorop_sm100_f16xf16 + cutlass_test_unit_gemm_device_tensorop_sm100_f8xf8 + cutlass_test_unit_gemm_device_tensorop_sm100_s8xs8 +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_f16xf16 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + f16_f16_void_f32.cu + f16_f16_f16_f16_fusion.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_f8xf8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + f8_f8_void_f32.cu + f8_f8_f16_f8_fusion.cu +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_s8xs8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + s8_s8_void_s32.cu + s8_s8_s32_s32_fusion.cu +) + +add_subdirectory(narrow_precision) diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_f16_f16_fusion.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_f16_f16_fusion.cu new file mode 100644 index 00000000..3ebe73c2 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_f16_f16_fusion.cu @@ -0,0 +1,607 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Inference fprop fusions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_bias_relu) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * acc + beta * C + per-row bias + // D = ReLU(Z) + using ElementBias = cutlass::half_t; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_bias_gelu) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * acc + beta * C + per-row bias + // D = GELU(Z) + using ElementBias = cutlass::half_t; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::GELU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Training fprop fusions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_bias_relu_aux) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * acc + beta * C + per-row bias + // D = ReLU(Z) + // For ReLU with uint1b_t aux, aux computes the dReLU/dZ gradient, i.e. + // Aux(i) = Z(i) >= 0 ? 1 : 0 + using ElementBias = cutlass::half_t; + using ElementAux = cutlass::uint1b_t; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + GmemLayoutAux, + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementAux, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_bias_gelu_aux) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * acc + beta * C + per-row bias + // D = GELU(Z) + // Aux = Z + using ElementBias = cutlass::half_t; + using ElementAux = cutlass::half_t; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + GmemLayoutAux, + cutlass::epilogue::thread::GELU, + ElementD, + ElementCompute, + ElementAux, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Backprop fusions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_dbias_drelu) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // dY = alpha * acc + beta * C + // D = dReLU(dY, Aux) + // dBias = sum of columns of D + using ElementBias = cutlass::half_t; + using ElementAux = cutlass::uint1b_t; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< + GmemLayoutAux, + cutlass::epilogue::thread::dReLU, + ElementD, + ElementCompute, + ElementAux, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_f16t_f16t_tensor_op_f32, 128x128x64_1x2x1_1sm_dbias_dgelu) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::half_t; + constexpr int AlignD = 8; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // dY = alpha * acc + beta * C + // D = dGELU(dY, Aux) + // dBias = sum of columns of D + using ElementBias = cutlass::half_t; + using ElementAux = cutlass::half_t; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< + GmemLayoutAux, + cutlass::epilogue::thread::dGELU, + ElementD, + ElementCompute, + ElementAux, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu new file mode 100644 index 00000000..8d8c71b9 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32.cu @@ -0,0 +1,655 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 64x64x64_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_64>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy <=== NEEDS TO BE 1SM otherwise ambigous. + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16n_void_f32t_tensor_op_f32, 64x128x64_1x4x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_64>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16n_f16n_void_f32t_tensor_op_f32, 128x64x64_1x8x1_streamK) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_64>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16t_void_f32n_tensor_op_f32, 128x128x64_2x8x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 128x64x64_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_64>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_f16t_f16n_void_f32n_tensor_op_f32, 128x128x64_16x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_16,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_64>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16n_f16n_void_f32n_tensor_op_f32, 256x64x64_4x1x1) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_64>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_f16t_f16t_void_f32n_tensor_op_f32, 256x256x64_2x1x1) { + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 8; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 8; + using ElementB = cutlass::half_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_64>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_f16_f8_fusion.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_f16_f8_fusion.cu new file mode 100644 index 00000000..70b70111 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_f16_f8_fusion.cu @@ -0,0 +1,430 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Inference fprop fusions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_f16t_e4m3t_tensor_op_f32, 128x128x128_1x2x1_1sm_bias_relu) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * scale_a * scale_b * acc + beta * scale_c * C + per-row bias + // D = scale_d * ReLU(Z) + using ElementBias = cutlass::half_t; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_f16t_f32t_tensor_op_f32, 128x128x128_1x2x1_1sm_bias_relu) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * scale_a * scale_b * acc + beta * scale_c * C + per-row bias + // D = ReLU(Z) + // scale_d is only applied if D is an fp8 type + using ElementBias = float; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Training fprop fusions +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_f16t_e4m3t_tensor_op_f32, 128x128x128_1x2x1_1sm_bias_relu_amax_aux) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = cutlass::float_e4m3_t; + constexpr int AlignD = 16; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * scale_a * scale_b * acc + beta * scale_c * C + per-row bias + // D = scale_d * ReLU(Z) + // Amax_D = max absolute value of ReLU(Z) + // Aux = Z + // scale_d and Amax_D are only computed if D is fp8 + using ElementBias = cutlass::half_t; + using ElementAmax = float; + using ElementAux = float; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutAux, + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementAux, + ElementAmax, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_f16t_f32t_tensor_op_f32, 128x128x128_1x2x1_1sm_bias_relu_amax_aux) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = cutlass::half_t; + constexpr int AlignC = 8; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_64>; + + // Epilogue fusion operation + // Z = alpha * scale_a * scale_b * acc + beta * scale_c * C + per-row bias + // D = ReLU(Z) + // Aux = scale_aux * Z + // Amax_Aux = max absolute value of Z + // scale_aux and Amax_Aux are only computed if Aux is fp8 + using ElementBias = float; + using ElementAmax = float; + using ElementAux = cutlass::float_e4m3_t; + using GmemLayoutAux = GmemLayoutC; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutAux, + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementAux, + ElementAmax, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu new file mode 100644 index 00000000..8cebebe5 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_f32.cu @@ -0,0 +1,659 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_void_f32n_tensor_op_f32, 64x64x128_4x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e5m2n_void_f32t_tensor_op_f32, 64x128x128_1x4x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2n_e4m3n_void_f32t_tensor_op_f32, 128x64x128_1x8x1) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e5m2t_void_f32n_tensor_op_f32, 128x128x128_2x8x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_e5m2n_e4m3t_void_f32n_tensor_op_f32, 128x64x128_2x4x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_void_f32n_tensor_op_f32, 128x128x128_16x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_16,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_void_f32n_tensor_op_f32, 256x64x128_4x1x1) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e4m3t_void_f32n_tensor_op_f32, 256x256x128_2x1x1_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/CMakeLists.txt b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/CMakeLists.txt new file mode 100644 index 00000000..6d69be9b --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/CMakeLists.txt @@ -0,0 +1,71 @@ +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# + +# + +add_custom_target( + cutlass_test_unit_gemm_device_sm100_tensorop_narrow_precision + DEPENDS + cutlass_test_unit_gemm_device_tensorop_sm100_f6f4xf6f4 + cutlass_test_unit_gemm_device_tensorop_sm100_f6f4xf8 + cutlass_test_unit_gemm_device_tensorop_sm100_f8xf6f4 +) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_f6f4xf6f4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + f6f4_f6f4_void_f32_tn_layout.cu + f6f4_f6f4_void_f32_nn_layout.cu + f6f4_f6f4_void_f32_nt_layout.cu + f6f4_f6f4_void_f32_tt_layout.cu + ) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_f6f4xf8 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + f6f4_f8_void_f32_tn_layout.cu + f6f4_f8_void_f32_nt_layout.cu + ) + +cutlass_test_unit_gemm_device_add_executable_split_file( + cutlass_test_unit_gemm_device_tensorop_sm100_f8xf6f4 + + BATCH_SOURCES ON + BATCH_SIZE 1 + + f8_f6f4_void_f32_tn_layout.cu + f8_f6f4_void_f32_nt_layout.cu + ) diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu new file mode 100644 index 00000000..2e1199ed --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nn_layout.cu @@ -0,0 +1,687 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}x{f6f4} Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e2m1n_e2m3n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2n_e2m1n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e2m1n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3n_e3m2n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2n_e3m2n_void_f32n_tensor_op_f32, 256x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e2m1n_void_f32n_tensor_op_f32, 256x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e2m3n_void_f32n_tensor_op_f32, 256x192x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e2m1n_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu new file mode 100644 index 00000000..78ac8975 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_nt_layout.cu @@ -0,0 +1,310 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}x{f6f4} Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e2m1n_e2m3t_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e2m1t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2n_e2m1t_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu new file mode 100644 index 00000000..6de91c7f --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tn_layout.cu @@ -0,0 +1,1283 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}x{f6f4} Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 64x64x128_4x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e2m3n_void_f32n_tensor_op_f32, 64x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 64x192x128_2x4x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2t_e3m2n_void_f32n_tensor_op_f32, 64x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e3m2n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e2m1n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 256x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e3m2n_void_f32n_tensor_op_f32, 256x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_void_f32n_tensor_op_f32, 256x192x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e3m2n_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu new file mode 100644 index 00000000..23ea1e75 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f6f4_void_f32_tt_layout.cu @@ -0,0 +1,536 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}x{f6f4} Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +TEST(SM100Only_Device_Gemm_e2m1t_e2m1t_void_f32n_tensor_op_f32, 64x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e2m3t_void_f32n_tensor_op_f32, 64x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1t_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2t_e3m2t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m1t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e2m3t_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu new file mode 100644 index 00000000..38f838ba --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_nt_layout.cu @@ -0,0 +1,686 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}xf8 Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3} + * Alignment: 16 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +TEST(SM100Only_Device_Gemm_e2m1n_e4m3t_void_f32n_tensor_op_f32, 128x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3n_e5m2t_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3n_e4m3t_void_f32n_tensor_op_f32, 128x192x128_2x4x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2n_e5m2t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e4m3t_void_f32n_tensor_op_f32, 256x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e5m2t_void_f32n_tensor_op_f32, 256x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1n_e4m3t_void_f32n_tensor_op_f32, 256x192x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3n_e4m3t_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu new file mode 100644 index 00000000..9a075b22 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f6f4_f8_void_f32_tn_layout.cu @@ -0,0 +1,1280 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}xf8 Gemm + + * A tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * B tensor: + * Types: {e5m2,e4m3} + * Alignment: 16 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | N | N | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e2m1t_e4m3n_void_f32n_tensor_op_f32, 64x64x128_4x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e5m2n_void_f32n_tensor_op_f32, 64x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_e3m2t_e4m3n_void_f32n_tensor_op_f32, 64x192x128_2x4x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e5m2n_void_f32n_tensor_op_f32, 64x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2t_e4m3n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e5m2n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e4m3n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e5m2n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m3t_e4m3n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m3_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e5m2n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e5m2n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2t_e5m2n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e4m3n_void_f32n_tensor_op_f32, 256x64x128_4x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e3m2t_e5m2n_void_f32n_tensor_op_f32, 256x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e3m2_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e4m3n_void_f32n_tensor_op_f32, 256x192x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e4m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e2m1t_e5m2n_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e2m1_t; + constexpr int AlignA = 128; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = cutlass::float_e5m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu new file mode 100644 index 00000000..dbc55fe9 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_nt_layout.cu @@ -0,0 +1,538 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f8}x{f6f4} Gemm + + * A tensor: + * Types: {e5m2,e4m3} + * Alignment: 16 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +TEST(SM100Only_Device_Gemm_e4m3n_e2m3t_void_f32n_tensor_op_f32, 64x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2n_e3m2t_void_f32n_tensor_op_f32, 64x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3n_e2m1t_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2n_e2m3t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3n_e3m2t_void_f32n_tensor_op_f32, 128x256x128_2x2x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2n_e2m1t_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu new file mode 100644 index 00000000..52d33db4 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/narrow_precision/f8_f6f4_void_f32_tn_layout.cu @@ -0,0 +1,1285 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Unit tests for {f6f4}x{f6f4} Gemm + + * A tensor: + * Types: {e5m2,e4m3} + * Alignment: 16 elements + * B tensor: + * Types: {e2m1,e2m3,e3m2} + * Alignment: 128 elements + * Mma Tile Shapes supported: + Support Matrix (Y: Yes, N: No) + | 1/2 SM | Mma Tile Shape | TN | TT | NT | NN | Dispatch Policy | + |--------|----------------|----|----|----|----|------------------------------------| + | 1SM | 64x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 64x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x128x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 1SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized1SmSm100` | + | 2SM | 128x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 128x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x64x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x128x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x192x128 | Y | N | N | Y | `KernelTmaWarpSpecialized2SmSm100` | + | 2SM | 256x256x128 | Y | Y | Y | Y | `KernelTmaWarpSpecialized2SmSm100` | +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../../common/cutlass_unit_test.h" +#include "../../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_e4m3t_e2m1n_void_f32n_tensor_op_f32, 64x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e2m3n_void_f32n_tensor_op_f32, 64x128x128_2x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m1n_void_f32n_tensor_op_f32, 64x192x128_2x4x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e3m2n_void_f32n_tensor_op_f32, 64x256x128_2x2x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e2m1n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m3n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e3m2n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_1sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m3n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e3m2n_void_f32n_tensor_op_f32, 128x64x128_4x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e2m3n_void_f32n_tensor_op_f32, 128x128x128_2x1x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m1n_void_f32n_tensor_op_f32, 128x192x128_2x4x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e3m2n_void_f32n_tensor_op_f32, 128x256x128_2x2x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e3m2_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e2m1n_void_f32n_tensor_op_f32, 256x64x128_4x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e5m2t_e2m1n_void_f32n_tensor_op_f32, 256x128x128_2x1x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e5m2_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m3n_void_f32n_tensor_op_f32, 256x192x128_2x4x1_2sm) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m3_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_192,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_192,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_e4m3t_e2m1n_void_f32n_tensor_op_f32, 256x256x128_2x2x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 128; + using ElementB = cutlass::float_e2m1_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_s32_s32_fusion.cu b/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_s32_s32_fusion.cu new file mode 100644 index 00000000..09d31e57 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_s32_s32_fusion.cu @@ -0,0 +1,226 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_s8t_s8n_s32t_s32t_tensor_op_f32, 128x128x128_1x2x1_1sm_rowscale_bias_relu) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = int32_t; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Epilogue fusion operation + // Z = per-row alpha * acc + per-row beta * C + per-row bias + // D = ReLU(Z) + using ElementBias = int32_t; + using FusionOperation = cutlass::epilogue::fusion::PerRowLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8t_s8n_s32t_s32t_tensor_op_f32, 128x128x128_1x2x1_1sm_colscale_bias_relu) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = int32_t; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = float; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_2,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // Epilogue fusion operation + // Z = per-col alpha * acc + per-col beta * C + per-col bias + // D = ReLU(Z) + using ElementBias = int32_t; + using FusionOperation = cutlass::epilogue::fusion::PerColLinCombPerColBiasEltAct< + cutlass::epilogue::thread::ReLU, + ElementD, + ElementCompute, + ElementBias, + ElementC>; + + // + // Construct CollectiveEpilogue + // + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy + FusionOperation // Epilogue fusion operation + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu b/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu new file mode 100644 index 00000000..3ee097a5 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/s8_s8_void_s32.cu @@ -0,0 +1,659 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100Only_Device_Gemm_s8n_s8t_void_s32n_tensor_op_f32, 64x64x128_4x1x1_1sm_streamK) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8t_s8n_void_s32t_tensor_op_f32, 64x128x128_1x4x1_1sm) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8n_s8n_void_s32t_tensor_op_f32, 128x64x128_1x8x1_streamK) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::RowMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::RowMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_1,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8t_s8t_void_s32n_tensor_op_f32, 128x128x128_2x8x1_1sm) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_8,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_s8n_s8t_void_s32n_tensor_op_f32, 128x64x128_2x4x1_2sm_streamK) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_4,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + + +TEST(SM100Only_Device_Gemm_s8t_s8n_void_s32n_tensor_op_f32, 128x128x128_16x1x1_2sm) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_128,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_16,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_64,_128,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8n_s8n_void_s32n_tensor_op_f32, 256x64x128_4x1x1_streamK) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::ColumnMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_64,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_64,_128>; + + // Tile Scheduler + using TileScheduler = cutlass::gemm::StreamKScheduler; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +TEST(SM100Only_Device_Gemm_s8t_s8t_void_s32n_tensor_op_f32, 256x256x128_2x1x1) { + // Describe A and B tensors + using ElementA = int8_t; + constexpr int AlignA = 16; + using GmemLayoutA = cutlass::layout::RowMajor; + constexpr int AlignB = 16; + using ElementB = int8_t; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 4; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = int32_t; + constexpr int AlignD = 4; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = int32_t; + // Epilogue computation's precision type + using ElementCompute = int32_t; + + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_256,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_1,_1>; + // Collective Epilogue takes the output tile shape for 1 CTA + using PerSmTileShape_MNK = Shape<_128,_256,_128>; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} +#endif diff --git a/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu b/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu index ae2fa4e2..29dd176b 100644 --- a/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu +++ b/test/unit/pipeline/pipeline_cluster_launch_control_async_warp_specialized_blackwell.cu @@ -282,100 +282,100 @@ struct PipelineTest { #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) //Cluster1x2 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x2_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<1, 2, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster2x1 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x1_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster2x2 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x2_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<2, 2, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster1x1 Stage3 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x1_Stage3) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<1, 1, 1>; static constexpr uint32_t Stages = 3; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster1x4 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster1x4_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<1, 4, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster4x1 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x1_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<4, 1, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster2x4 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster2x4_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<2, 4, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster4x2 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x2_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<4, 2, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } //Cluster4x4 Stage4 TEST(SM100_Verify_PipelineClusterLaunchControlAsync_WS, Cluster4x4_Stage4) { - Options options; + OptionsClusterLaunch options; options.grid_dim = {32,32,1}; using ClusterShape = cutlass::gemm::GemmShape<4, 4, 1>; static constexpr uint32_t Stages = 4; using Test = PipelineTest; - Testbed testbed(options); + TestbedClusterLaunch testbed(options); EXPECT_TRUE(testbed.verification()); } #endif diff --git a/test/unit/pipeline/testbed_cluster_launch_control.h b/test/unit/pipeline/testbed_cluster_launch_control.h index 4ac892de..49f65d8a 100644 --- a/test/unit/pipeline/testbed_cluster_launch_control.h +++ b/test/unit/pipeline/testbed_cluster_launch_control.h @@ -51,7 +51,7 @@ #include "cutlass/util/command_line.h" // Command line test options -struct Options { +struct OptionsClusterLaunch { // // Data Members // @@ -95,10 +95,10 @@ struct Options { // template -class Testbed { +class TestbedClusterLaunch { private: // Commandline options - Options options; + OptionsClusterLaunch options; bool run_test() { @@ -114,7 +114,7 @@ private: public: - Testbed(Options const &options_) : options(options_) { + TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { int device_id = 0; cudaDeviceProp device_prop; CUTE_CHECK_ERROR(cudaSetDevice(device_id)); diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index bfe1b5f4..febce464 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -222,8 +222,8 @@ cutlass_add_cutlass_library( # files split for parallel compilation src/reference/gemm_int4.cu - src/reference/block_scaled_gemm_fp4a_vs16.cu - src/reference/block_scaled_gemm_fp4a_vs32.cu + src/reference/block_scaled_gemm_fp4a_vs16.cu + src/reference/block_scaled_gemm_fp4a_vs32.cu src/reference/block_scaled_gemm_mixed8bitsa.cu src/reference/gemm_f4_f4_f32.cu src/reference/gemm_f4_f6_f32.cu diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 0309ec31..0aea3126 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -43,7 +43,8 @@ computational overhead */ -#pragma once +#ifndef CUTLASS_LIBRARY_LIBRARY_H +#define CUTLASS_LIBRARY_LIBRARY_H ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -103,7 +104,7 @@ public: void *device_workspace = nullptr, cudaStream_t stream = nullptr) const = 0; - // Originally designed for metadata, but should be useful for FP8/6/4 too. + // Originally designed for metadata, but should be useful for FP8/6/4 too. virtual Status initialize_with_profiler_workspace( void const *configuration, void *host_workspace, @@ -282,6 +283,11 @@ struct GemmUniversalConfiguration { int device_count{1}; }; +enum class Sm90MixedInputWiderOperand { + A = 0, + B = 1 +}; + struct GemmUniversalArguments { // NOTE: these are replicated for 3.0 interfaces gemm::GemmCoord problem_size{}; @@ -317,6 +323,18 @@ struct GemmUniversalArguments { int swizzle_size{1}; int split_k_slices{1}; + // For mixed input dtype kernels + bool is_mixed_dtype{false}; + Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; + bool generate_scale_and_zero{false}; + bool generate_dequantized_AB{false}; + bool *dequantized_AB_ready{nullptr}; // Carry the info back to gemm_operation_profiler.cu + void *Scale{nullptr}; // Scale tensor + void *Zero{nullptr}; // Zero tensor + void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + int device_index{0}; bool use_pdl{false}; @@ -472,12 +490,16 @@ struct GemmPlanarComplexArrayArguments { struct GemmGroupedConfiguration { int problem_count{0}; - int threadblock_count{0}; + // GemmGroupedConfiguration is passed to initialize(), which + // is responsible for allocating the device-side stride storage. + int64_t* lda; + int64_t* ldb; + int64_t* ldc; }; struct GemmGroupedArguments { - - gemm::GemmCoord *problem_sizes{nullptr}; + int problem_count{}; + gemm::GemmCoord* problem_sizes{nullptr}; void * ptr_A{nullptr}; void * ptr_B{nullptr}; @@ -493,6 +515,18 @@ struct GemmGroupedArguments { void const *beta{nullptr}; ScalarPointerMode pointer_mode{}; bool use_pdl{false}; + + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + + // these should really be in the configuration but staying consistent with GEMM + int sm_count{0}; + // The user is responsible for allocating storage for problem sizes. + // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we + // unfortunately need to have both options in this struct, and the + // underlying operation uses the one it needs. + cute::Shape* problem_sizes_3x; + cute::Shape* problem_sizes_3x_host; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -880,3 +914,5 @@ struct ReductionArguments { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/tools/library/include/cutlass/library/types.h b/tools/library/include/cutlass/library/types.h index 4b0e36fe..ebc0b1bd 100644 --- a/tools/library/include/cutlass/library/types.h +++ b/tools/library/include/cutlass/library/types.h @@ -142,7 +142,7 @@ enum class Provider { /// Enumeration indicating the kind of operation enum class OperationKind { kGemm, - kBlockScaledGemm, + kBlockScaledGemm, kRankK, kRank2K, kTrmm, @@ -152,6 +152,7 @@ enum class OperationKind { kEqGemm, kSparseGemm, kReduction, + kGroupedGemm, kInvalid }; @@ -270,7 +271,6 @@ enum class RuntimeDatatype { kStatic, kE4M3, kE5M2, - kE3M2, kE2M3, kE2M1, diff --git a/tools/library/include/cutlass/library/util.h b/tools/library/include/cutlass/library/util.h index bf763e15..f5374217 100644 --- a/tools/library/include/cutlass/library/util.h +++ b/tools/library/include/cutlass/library/util.h @@ -34,7 +34,8 @@ \brief Utilities accompanying the CUTLASS library for interacting with Library types. */ -#pragma once +#ifndef CUTLASS_LIBRARY_UTIL_H +#define CUTLASS_LIBRARY_UTIL_H #include "cutlass/cutlass.h" #include "cutlass/library/library.h" @@ -213,6 +214,63 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + return Status::kInvalid; \ + } \ + } while (0) + +// RAII CUDA buffer container +class CudaBuffer { +public: + CudaBuffer() : size_(0), d_ptr_(nullptr) {} + + explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { + cudaError_t err = cudaMalloc(&d_ptr_, size_); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); + } + } + + ~CudaBuffer() { + if (d_ptr_) { + cudaFree(d_ptr_); + } + } + + CudaBuffer(CudaBuffer const&) = delete; + CudaBuffer& operator=(CudaBuffer const&) = delete; + + CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { + other.d_ptr_ = nullptr; + other.size_ = 0; + } + + CudaBuffer& operator=(CudaBuffer&& other) noexcept { + if (this != &other) { + if (d_ptr_) { + cudaFree(d_ptr_); + } + d_ptr_ = other.d_ptr_; + size_ = other.size_; + other.d_ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + void* data() const noexcept { return d_ptr_; } + size_t size() const noexcept { return size_; } + +private: + size_t size_; + void* d_ptr_; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library @@ -220,3 +278,4 @@ NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); ///////////////////////////////////////////////////////////////////////////////////////////////// +#endif diff --git a/tools/library/src/block_scaled_gemm_operation_3x.hpp b/tools/library/src/block_scaled_gemm_operation_3x.hpp index b95f72ec..d1cd3517 100644 --- a/tools/library/src/block_scaled_gemm_operation_3x.hpp +++ b/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -73,7 +73,7 @@ public: using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; using Sm100BlkScaledConfig = typename CollectiveMainloop::Sm100BlkScaledConfig; - + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; using ElementSFD = cute::conditional_t; diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 0ff45dae..92f62bf4 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -1201,25 +1201,30 @@ public: GemmOperationBase(name) { this->description_.gemm_kind = GemmKind::kGrouped; + this->description_.kind = OperationKind::kGroupedGemm; + this->threadblock_count = Operator::sufficient(); } +private: + int threadblock_count; + protected: /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( + Status construct_arguments_( OperatorArguments &op_args, - GemmGroupedConfiguration const *config) { + GemmGroupedConfiguration const *config) const { op_args.problem_count = config->problem_count; - op_args.threadblock_count = config->threadblock_count; + op_args.threadblock_count = threadblock_count; return Status::kSuccess; } /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( + Status update_arguments_( OperatorArguments &op_args, - GemmGroupedArguments const *arguments) { + GemmGroupedArguments const *arguments) const { if (arguments->pointer_mode == ScalarPointerMode::kHost) { @@ -1243,6 +1248,8 @@ protected: return Status::kErrorInvalidProblem; } + op_args.threadblock_count = threadblock_count; + op_args.problem_count = arguments->problem_count; op_args.problem_sizes = arguments->problem_sizes; op_args.ptr_A = static_cast(arguments->ptr_A); diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 91f579bf..704edad3 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -36,9 +36,17 @@ #include "cutlass/cutlass.h" #include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" #include "cutlass/library/library.h" #include "library_internal.h" #include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" #include /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -65,7 +73,7 @@ public: using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; -private: +protected: GemmDescription description_; public: @@ -178,7 +186,23 @@ public: /// Constructor GemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) {} + GemmOperation3xBase(name, GemmKind::kUniversal) { + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + +private: + int max_active_clusters{}; protected: @@ -227,10 +251,119 @@ protected: } }; - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( + template class Policy, int Stages, class ClusterShape, class KernelSchedule> + static constexpr bool is_mixed_dtype_mainloop_(Policy policy) { + return (cute::is_same_v, + cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); + } + + template + static constexpr bool is_mixed_dtype_mainloop_(DispatchPolicy) { + return false; + } + + template < + typename ElementWide, + typename ElementNarrow, + typename ElementScaleMainloop, + class ActualStrideAB, + Sm90MixedInputWiderOperand wider_operand, + bool is_n4w8, + typename ElementScale, + typename ElementZero, + class Layout_SZ> + static void dequantize_encode_( OperatorArguments &operator_args, - GemmUniversalArguments const *arguments) { + GemmUniversalArguments const *arguments, + cudaStream_t stream, + const int &problem_mn, + const int &problem_k, + const int &options_l, + const int &options_g, + ElementScale *ptr_S, + ElementZero *ptr_Z, + const size_t &SZ_size, + Layout_SZ layout_SZ + ) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); + const ElementNarrow *ptr_AB = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + ptr_AB = static_cast(arguments->B); + } + else { + ptr_AB = static_cast(arguments->A); + } + dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream); + if constexpr(is_n4w8) { + size_t AB_size = cute::size(layout_AB); + cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); + unified_encode_int4b(ptr_AB, encoded_AB, AB_size); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = static_cast(encoded_AB); + } + else { + operator_args.mainloop.ptr_A = static_cast(encoded_AB); + } + ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); + pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); + } + } + + template < + typename ElementAB, + class ActualStrideAB, + class LayoutAB_Reordered, + class LayoutAtomQuant, + Sm90MixedInputWiderOperand wider_operand> + static void handle_shuffle_tensor_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + const int &problem_mn, + const int &problem_k, + const int &options_l) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.dB = layout_AB_reordered; + } + else { + operator_args.mainloop.dA = layout_AB_reordered; + } + if (arguments->generate_dequantized_AB) { + size_t AB_size = cute::size(layout_AB); + ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); + const ElementAB *AB_src = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + AB_src = static_cast(operator_args.mainloop.ptr_B); + } + else { + AB_src = static_cast(operator_args.mainloop.ptr_A); + } + reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); + ElementAB *AB_dst = static_cast(arguments->encoded_AB); + cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size); + cutlass::device_memory::free(AB_reordered); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = AB_dst; + } + else { + operator_args.mainloop.ptr_A = AB_dst; + } + } + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments& operator_args, + GemmUniversalArguments const* arguments, + cudaStream_t stream = nullptr) const { Status status = Status::kSuccess; status = UpdateFusionArgs::update_( @@ -286,24 +419,173 @@ protected: operator_args.epilogue.ptr_C = static_cast(arguments->C); operator_args.epilogue.ptr_D = static_cast(arguments->D); - operator_args.mainloop.dA = cute::make_int_tuple_from( + // Stride{A,B} is a Layout if and only if: + // (1) This is a mixed dtype kernel, and + // (2) This mixed dtype kernel is using shuffling, and + // (3) sizeof(narrow_type) == 4 or 8 bits, and + // (4) sizeof(wide_type) == 16 bits. + // If A/B has the narrow data type, Stride{A/B} will be a Layout + constexpr bool is_StrideA_Layout = cute::is_layout::value; + constexpr bool is_StrideB_Layout = cute::is_layout::value; + static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); + if constexpr(!is_StrideA_Layout) { + operator_args.mainloop.dA = cute::make_int_tuple_from( arguments->lda, arguments->batch_stride_A); - operator_args.mainloop.dB = cute::make_int_tuple_from( + } + if constexpr(!is_StrideB_Layout) { + operator_args.mainloop.dB = cute::make_int_tuple_from( arguments->ldb, arguments->batch_stride_B); + } operator_args.epilogue.dC = cute::make_int_tuple_from( arguments->ldc, arguments->batch_stride_C); operator_args.epilogue.dD = operator_args.epilogue.dC; + using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; + if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{})) { + int problem_m = arguments->problem_size.m(); + int problem_n = arguments->problem_size.n(); + int problem_k = arguments->problem_size.k(); + int options_l = arguments->batch_count; + + constexpr Sm90MixedInputWiderOperand wider_operand = + (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? + Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; + using ElementWide = std::conditional_t; + using ElementNarrow = std::conditional_t; + + constexpr bool has_scale = !std::is_same_v; + constexpr bool has_zero = !std::is_same_v; + if constexpr(has_scale) { + int options_g = problem_k; + int scale_k = (problem_k + options_g - 1) / options_g; + + constexpr bool is_A4B8 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_A8B4 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; + + // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element + using ElementScaleMainloop = typename CollectiveMainloop::ElementScale; + using ElementScale = typename UnderlyingElement::type; + using StrideS = typename CollectiveMainloop::StrideScale; + // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S + using ElementZero = std::conditional_t< + has_zero, + typename CollectiveMainloop::ElementZero, + ElementScale + >; + const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; + const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); + auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); + ElementScale *ptr_S = static_cast(arguments->Scale); + ElementZero *ptr_Z = static_cast(arguments->Zero); + + // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled + if (arguments->generate_scale_and_zero) { + // Need to fix max_dequant_val and min_dequant_val? + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + const float scale_max = max_dequant_val / elt_max_f; + const float scale_min = min_dequant_val / elt_max_f; + uint64_t seed = 2023; + cutlass::reference::device::BlockFillRandomUniform( + ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); + + // In ScaleOnly mode, set Z as zero for generating dequantized A or B + const float zero_max = has_zero ? 2.0f : 0.0f; + const float zero_min = has_zero ? -2.0f : 0.0f; + cutlass::reference::device::BlockFillRandomUniform( + ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); + } // End of "if (arguments->generate_scale_and_zero)" + + // 2. Generate the dequantized A or B for verification + if (arguments->generate_dequantized_AB) { + StrideS stride_SZ = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ); + auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + if constexpr(is_StrideB_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of B later + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideB = typename CollectiveMainloop::StrideB; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } + else { + if constexpr(is_StrideA_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of A later + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideA = typename CollectiveMainloop::StrideA; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" + arguments->dequantized_AB_ready[0] = true; + } // End of "if (arguments->generate_dequantized_AB)" + + // 3. Put arguments in mainloop + if constexpr(is_int4_x_fp8) { + operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); + } + else { + operator_args.mainloop.ptr_S = static_cast(arguments->Scale); + } + operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideS{}, shape_SZ); + operator_args.mainloop.group_size = options_g; + if constexpr(has_zero) { + operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); + } + } // End of "if constexpr(has_scale)" + + // Handle the shuffling + using ValueShuffle = std::conditional_t< + cutlass::sizeof_bits::value == 4, + cute::Layout, cute::Stride>, + cute::Layout, cute::Stride> + >; + constexpr int NumShuffleAtoms = 1; + using MmaAtomShape = cute::Layout>>; + using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout and stride of A/B later + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + using LayoutB_Reordered = typename CollectiveMainloop::StrideB; + handle_shuffle_tensor_( + operator_args, arguments, problem_n, problem_k, options_l); + } + if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + using LayoutA_Reordered = typename CollectiveMainloop::StrideA; + handle_shuffle_tensor_( + operator_args, arguments, problem_m, problem_k, options_l); + } + } // End of "if constexpr(is_mixed_dtype_mainloop_(MainloopPolicy{}))" + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ operator_args.hw_info.sm_count = arguments->sm_count; if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { - dim3 cluster_dims(cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), - cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); - uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; - void const* kernel_ptr = (void*)(device_kernel); - operator_args.hw_info.max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( - cluster_dims, threads_per_block, kernel_ptr); + operator_args.hw_info.max_active_clusters = max_active_clusters; } if constexpr (!std::is_const_v) { operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; @@ -356,7 +638,10 @@ public: return status; } - return Operator::can_implement(args); + Status can_impl = Operator::can_implement(args); + + //return Operator::can_implement(args); + return can_impl; } /// Gets the host-side workspace @@ -397,7 +682,7 @@ public: cudaStream_t stream = nullptr) const override { OperatorArguments args; - Status status = update_arguments_(args, static_cast(arguments_ptr)); + Status status = update_arguments_(args, static_cast(arguments_ptr), stream); if (status != Status::kSuccess) { return status; } diff --git a/tools/library/src/grouped_gemm_operation_3x.hpp b/tools/library/src/grouped_gemm_operation_3x.hpp new file mode 100644 index 00000000..a07ce63b --- /dev/null +++ b/tools/library/src/grouped_gemm_operation_3x.hpp @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all grouped GEMM operations in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "gemm_operation_3x.hpp" +#include "library_internal.h" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/// **** CAUTION **** +/// Unlike other operations, initialize() must be called when +/// certain arguments change. See initialize() for details. +template +class GroupedGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + +private: + mutable CudaBuffer strideA_device; + mutable CudaBuffer strideB_device; + mutable CudaBuffer strideC_device; + mutable CudaBuffer strideD_device; + mutable std::vector strideA_host; + mutable std::vector strideB_host; + mutable std::vector strideC_host; + mutable std::vector strideD_host; + +public: + GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GemmOperation3xBase(name, GemmKind::kGrouped) { + this->description_.kind = OperationKind::kGroupedGemm; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + + ~GroupedGemmUniversal3xOperation() override = default; + +private: + int max_active_clusters{}; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + Status + update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { + + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; + operator_args.problem_shape = { + arguments->problem_count, + arguments->problem_sizes_3x, + arguments->pointer_mode == ScalarPointerMode::kHost ? arguments->problem_sizes_3x_host + : nullptr}; + operator_args.mainloop.ptr_A = + static_cast(arguments->ptr_A); + operator_args.mainloop.ptr_B = + static_cast(arguments->ptr_B); + operator_args.epilogue.ptr_C = + static_cast(arguments->ptr_C); + operator_args.epilogue.ptr_D = static_cast(arguments->ptr_D); + + operator_args.mainloop.dA = + static_cast(strideA_device.data()); + operator_args.mainloop.dB = + static_cast(strideB_device.data()); + operator_args.epilogue.dC = + static_cast(strideC_device.data()); + operator_args.epilogue.dD = + static_cast(strideD_device.data()); + + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + operator_args.hw_info.max_active_clusters = max_active_clusters; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + + return status; + } + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GemmGroupedArguments const* arguments = static_cast(arguments_ptr); + OperatorArguments args; + + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const* configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + + auto num_groups = config.problem_count; + strideA_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); + strideB_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); + strideC_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); + strideD_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); + + strideA_host.resize(num_groups); + strideB_host.resize(num_groups); + strideC_host.resize(num_groups); + strideD_host.resize(num_groups); + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + strideA_host[group_idx] = + cute::make_int_tuple_from( + config.lda[group_idx]); + strideB_host[group_idx] = + cute::make_int_tuple_from( + config.ldb[group_idx]); + strideC_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + strideD_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + } + CUDA_CHECK(cudaMemcpy( + strideA_device.data(), + strideA_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + strideB_device.data(), + strideB_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + strideC_device.data(), + strideC_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + strideD_device.data(), + strideD_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index ad994acb..33e6e9a8 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -64,7 +64,6 @@ void initialize_gemm_reference_operations_f8_f6_f32(Manifest &manifest); void initialize_block_scaled_gemm_reference_operations_fp4a_vs16(Manifest &manifest); void initialize_block_scaled_gemm_reference_operations_fp4a_vs32(Manifest &manifest); void initialize_block_scaled_gemm_reference_operations_mixed8bitsa(Manifest &manifest); - void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest); void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest); void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest); @@ -114,7 +113,6 @@ void initialize_reference_operations(Manifest &manifest) { initialize_block_scaled_gemm_reference_operations_fp4a_vs16(manifest); initialize_block_scaled_gemm_reference_operations_fp4a_vs32(manifest); initialize_block_scaled_gemm_reference_operations_mixed8bitsa(manifest); - } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/sparse_gemm_operation_3x.hpp b/tools/library/src/sparse_gemm_operation_3x.hpp index c77c2363..5c1429c4 100644 --- a/tools/library/src/sparse_gemm_operation_3x.hpp +++ b/tools/library/src/sparse_gemm_operation_3x.hpp @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/detail/collective.hpp" #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor #include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter #include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride @@ -45,14 +46,6 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -#define CUDA_CHECK(cuda_error) \ - { \ - if (cuda_error != cudaSuccess) { \ - printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ - return Status::kInvalid; \ - } \ - } - namespace cutlass::library { /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index 5b39e749..525b4794 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -330,18 +330,18 @@ static struct { char const *text; char const *pretty; OperationKind enumerant; -} -OperationKind_enumerants[] = { - {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, +} OperationKind_enumerants[] = { + {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, {"gemm", "Gemm", OperationKind::kGemm}, {"block_scaled_gemm", "blockScaledGemm", OperationKind::kBlockScaledGemm}, {"rank_k", "RankK", OperationKind::kRankK}, {"rank_2k", "Rank2K", OperationKind::kRank2K}, {"trmm", "Trmm", OperationKind::kTrmm}, {"symm", "Symm", OperationKind::kSymm}, - {"conv2d", "Conv2d", OperationKind::kConv2d}, - {"conv3d", "Conv3d", OperationKind::kConv3d}, + {"conv2d", "Conv2d", OperationKind::kConv2d}, + {"conv3d", "Conv3d", OperationKind::kConv3d}, {"spgemm", "SparseGemm", OperationKind::kSparseGemm}, + {"grouped_gemm", "GroupedGemm", OperationKind::kGroupedGemm}, }; /// Converts a Status enumerant to a string @@ -504,7 +504,6 @@ NumericTypeID_enumerants[] = { {"fe2m1", "FE2M1", NumericTypeID::kFE2M1}, {"fue8m0", "FUE8M0", NumericTypeID::kFUE8M0}, {"fue4m3", "FUE4M3", NumericTypeID::kFUE4M3}, - {"f16", "F16", NumericTypeID::kF16}, {"bf16", "BF16", NumericTypeID::kBF16}, {"f32", "F32", NumericTypeID::kF32}, @@ -577,7 +576,6 @@ int sizeof_bits(NumericTypeID type) { case NumericTypeID::kFE2M1: return 4; case NumericTypeID::kFUE8M0: return 8; case NumericTypeID::kFUE4M3: return 8; - case NumericTypeID::kF16: return 16; case NumericTypeID::kBF16: return 16; case NumericTypeID::kTF32: return 32; @@ -666,7 +664,6 @@ bool is_signed_type(NumericTypeID type) { case NumericTypeID::kFE2M1: return true; case NumericTypeID::kFUE8M0: return false; case NumericTypeID::kFUE4M3: return false; - case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -707,7 +704,6 @@ bool is_float_type(NumericTypeID type) { case NumericTypeID::kFE2M1: return true; case NumericTypeID::kFUE8M0: return true; case NumericTypeID::kFUE4M3: return true; - case NumericTypeID::kF16: return true; case NumericTypeID::kBF16: return true; case NumericTypeID::kTF32: return true; @@ -1256,7 +1252,6 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; - case NumericTypeID::kFE2M3: { float tmp; @@ -1292,7 +1287,6 @@ bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string c *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; - case NumericTypeID::kF16: { float tmp; @@ -1473,7 +1467,6 @@ std::string lexical_cast(std::vector &bytes, NumericTypeID type) { ss << tmp; } break; - case NumericTypeID::kF16: { float tmp = *reinterpret_cast(bytes.data()); @@ -1652,7 +1645,6 @@ bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t sr *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; - case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); @@ -1789,7 +1781,6 @@ bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; - case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); @@ -1927,7 +1918,6 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; - case NumericTypeID::kF16: { *reinterpret_cast(bytes.data()) = static_cast(float(src)); diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 0e257413..53f13ab9 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -46,7 +46,8 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES src/problem_space.cpp src/operation_profiler.cu src/gemm_operation_profiler.cu - src/block_scaled_gemm_operation_profiler.cu + src/grouped_gemm_operation_profiler.cu + src/block_scaled_gemm_operation_profiler.cu src/rank_k_operation_profiler.cu src/rank_2k_operation_profiler.cu src/trmm_operation_profiler.cu @@ -112,6 +113,7 @@ set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_K --operation=RankK --pro set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_RANK_2K --operation=Rank2K --providers=cutlass --verification-providers=cublas --junit-output=test_cutlass_profiler_rank_2k --print-kernel-before-running=true) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_TRMM --operation=Trmm --providers=cutlass --verification-providers=device,host --junit-output=test_cutlass_profiler_trmm --print-kernel-before-running=true) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_SYMM --operation=Symm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_symm --print-kernel-before-running=true) +set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GROUPED_GEMM --operation=GroupedGemm --providers=cutlass --verification-providers=device --junit-output=test_cutlass_profiler_grouped_gemm --print-kernel-before-running=true) cutlass_add_executable_tests( test_profiler cutlass_profiler @@ -125,6 +127,7 @@ cutlass_add_executable_tests( RANK_2K TRMM SYMM + GROUPED_GEMM TEST_COMMAND_OPTIONS_PREFIX CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_ DISABLE_EXECUTABLE_INSTALL_RULE diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h index 18b18cbc..4e5693a8 100644 --- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -108,6 +108,8 @@ public: bool use_pdl{false}; + bool enable_sm90_mixed_dtype_shuffle_test{false}; + // // Methods // @@ -160,6 +162,13 @@ public: /// Buffer used for the cutlass reduction operations' host workspace std::vector reduction_host_workspace; + /// For mixed input dtype kernels + DeviceAllocation *Scale{nullptr}; // Scale tensor + DeviceAllocation *Zero{nullptr}; // Zero tensor + DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + cudaStream_t stream; }; diff --git a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h new file mode 100644 index 00000000..d1871e15 --- /dev/null +++ b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -0,0 +1,263 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief GroupedGemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "device_context.h" +#include "operation_profiler.h" +#include "options.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GroupedGemmOperationProfiler : public OperationProfiler { +public: + /// Problem structure obtained from problem space + struct GroupedGemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; + + std::vector problem_sizes; + std::vector> problem_sizes_3x; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + std::vector lda{0}; + std::vector ldb{0}; + std::vector ldc{0}; + + std::vector alpha; + std::vector beta; + + /// Parses the problem + Status parse( + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); }; + int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; + int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; + + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const& operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const& operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult& result, + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space); + }; + + // workspace contains the allocated blocks, arguments just contain the raw + // pointers + struct GroupedGemmWorkspace { + + std::vector A_ptr_array_device; + std::vector B_ptr_array_device; + std::vector C_ptr_array_device; + std::vector D_ptr_array_device; + std::vector reference_ptr_array_host; + std::vector A_ptr_array_host; + std::vector B_ptr_array_host; + std::vector C_ptr_array_host; + std::vector D_ptr_array_host; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + /// *NOT* the number of groups in the grouped GEMM + int problem_count{1}; + + DeviceAllocation* problem_sizes_array_device{nullptr}; + DeviceAllocation* problem_sizes_3x_array_device{nullptr}; + DeviceAllocation* lda_array_device{nullptr}; + DeviceAllocation* ldb_array_device{nullptr}; + DeviceAllocation* ldc_array_device{nullptr}; + DeviceAllocation* ldd_array_device{nullptr}; + + library::GemmGroupedConfiguration configuration; + library::GemmGroupedArguments arguments; + + std::vector host_workspace; + DeviceAllocation device_workspace; + }; + +private: + void init_arguments(Options const& options) { + gemm_workspace_.arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); + gemm_workspace_.arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); + gemm_workspace_.arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); + gemm_workspace_.arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); + gemm_workspace_.arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); + gemm_workspace_.arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); + gemm_workspace_.arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); + gemm_workspace_.arguments.problem_sizes = + static_cast(gemm_workspace_.problem_sizes_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x = static_cast*>( + gemm_workspace_.problem_sizes_3x_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + gemm_workspace_.arguments.sm_count = options.device.properties[0].multiProcessorCount; + } + +protected: + /// GEMM problem obtained from problem space + GroupedGemmProblem problem_; + + /// Device memory allocations + GroupedGemmWorkspace gemm_workspace_; + +public: + GroupedGemmOperationProfiler(Options const& options); + + virtual ~GroupedGemmOperationProfiler(); + + GroupedGemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream& out) const; + + /// Prints examples + virtual void print_examples(std::ostream& out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Measures performance results + virtual bool profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +protected: + /// Initializes the performance result + void initialize_result_( + PerformanceResult& result, + Options const& options, + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) override; + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const* operation, + ProblemSpace::Problem const& problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/include/cutlass/profiler/operation_profiler.h b/tools/profiler/include/cutlass/profiler/operation_profiler.h index 185e6f03..446ef2c1 100644 --- a/tools/profiler/include/cutlass/profiler/operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -241,17 +241,30 @@ protected: /// Profiles the GPU kernel launched in `func` running simultaneously on all /// requested devices. + Status profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + Status profile_kernel_( - PerformanceResult &result, - Options const &options, - const std::function &func, - const std::vector &streams); + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); /// Profiles the GPU kernel launched in `func` on the `stream` Status profile_kernel_( - PerformanceResult &result, - Options const &options, - const std::function &func, + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, cudaStream_t stream = nullptr); private: diff --git a/tools/profiler/include/cutlass/profiler/options.h b/tools/profiler/include/cutlass/profiler/options.h index 449aa70e..19c9ea6a 100644 --- a/tools/profiler/include/cutlass/profiler/options.h +++ b/tools/profiler/include/cutlass/profiler/options.h @@ -208,6 +208,8 @@ public: /// Minimum number of iterations to profile int min_iterations{10}; + bool use_cuda_graphs{false}; + /// Number of ms to sleep between profiling periods (ms) int sleep_duration{50}; diff --git a/tools/profiler/include/cutlass/profiler/problem_space.h b/tools/profiler/include/cutlass/profiler/problem_space.h index 03903e3e..9bdbec65 100644 --- a/tools/profiler/include/cutlass/profiler/problem_space.h +++ b/tools/profiler/include/cutlass/profiler/problem_space.h @@ -988,6 +988,12 @@ bool arg_as_scalar( ProblemSpace const &problem_space, ProblemSpace::Problem const &problem); +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + /// Returns true if a tensor description satisfies a `tensor` value bool tensor_description_satisfies( library::TensorDescription const &tensor_desc, diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu index 81ea0522..d3b0b6bd 100644 --- a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu +++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu @@ -75,7 +75,6 @@ BlockScaledGemmOperationProfiler::BlockScaledGemmOperationProfiler(Options const {ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D output"}, {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, - // TODO: Bring these back once SM100 future audits are complete {ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"}, {ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"}, {ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"}, @@ -113,14 +112,11 @@ void BlockScaledGemmOperationProfiler::print_examples(std::ostream &out) const { << "Schmoo over problem size and beta:\n" << " $ cutlass_profiler --operation=block_scaled_gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" - // TODO: Bring these back once SM100 future audits are complete -#if 0 - << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" + << "For column major, use column, col, or n. For row major use, row or t:\n" << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n" << "Profile a particular problem size with split K and parallel reduction:\n" << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" -#endif << "Using various input value distribution:\n" << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n" @@ -225,7 +221,6 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse( this->split_k_slices = 1; } - // TODO: Bring these back once SM100 future audits are complete if (this->split_k_mode != library::SplitKMode::kSerial) { std::cout<<"SplitK/StreamK feature is not supported yet!"; return Status::kErrorInvalidProblem; @@ -403,7 +398,6 @@ void BlockScaledGemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback); - // TODO: Bring these back once SM100 future audits are complete set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); @@ -536,8 +530,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_( library::Operation const *operation, ProblemSpace::Problem const &problem) { - // TODO: Bring these back once SM100 future audits are complete -#if 1 library::BlockScaledGemmDescription const &gemm_desc = static_cast(operation->description()); @@ -577,8 +569,6 @@ bool BlockScaledGemmOperationProfiler::initialize_reduction_configuration_( // reduction operation found and initialized return true; -#endif - return false; } /// Initializes workspace diff --git a/tools/profiler/src/cublas_helpers.cu b/tools/profiler/src/cublas_helpers.cu index 612ccfc5..b39fdd9d 100644 --- a/tools/profiler/src/cublas_helpers.cu +++ b/tools/profiler/src/cublas_helpers.cu @@ -545,6 +545,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle, cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, requestedAlgoCount, heuristicResult, &returnedResults); if (returnedResults == 0) { + cudaFree(workspaceHeuristic); return false; } @@ -589,6 +590,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle, // Handle errors if (status != CUBLAS_STATUS_SUCCESS) { std::cerr << "cublasLtMatmul AutoTuning failed with status: " << cublasLtGetStatusName(status) << std::endl; + cudaFree(workspaceHeuristic); return false; } @@ -653,6 +655,7 @@ bool cublasLtGemmExDispatcher::get_cublaslt_algo(cublasLtHandle_t handle, throw std::bad_alloc(); } + cudaFree(workspaceHeuristic); return true; } diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index efffefb7..6ecee707 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -36,16 +36,17 @@ #include // Profiler includes -#include "cutlass/profiler/cutlass_profiler.h" -#include "cutlass/profiler/gemm_operation_profiler.h" -#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h" -#include "cutlass/profiler/rank_k_operation_profiler.h" -#include "cutlass/profiler/rank_2k_operation_profiler.h" -#include "cutlass/profiler/trmm_operation_profiler.h" -#include "cutlass/profiler/symm_operation_profiler.h" +#include "cutlass/profiler/block_scaled_gemm_operation_profiler.h" #include "cutlass/profiler/conv2d_operation_profiler.h" #include "cutlass/profiler/conv3d_operation_profiler.h" +#include "cutlass/profiler/cutlass_profiler.h" +#include "cutlass/profiler/gemm_operation_profiler.h" +#include "cutlass/profiler/grouped_gemm_operation_profiler.h" +#include "cutlass/profiler/rank_2k_operation_profiler.h" +#include "cutlass/profiler/rank_k_operation_profiler.h" #include "cutlass/profiler/sparse_gemm_operation_profiler.h" +#include "cutlass/profiler/symm_operation_profiler.h" +#include "cutlass/profiler/trmm_operation_profiler.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -76,6 +77,8 @@ CutlassProfiler::CutlassProfiler( operation_profilers_.emplace_back(new TrmmOperationProfiler(options)); operation_profilers_.emplace_back(new SymmOperationProfiler(options)); + + operation_profilers_.emplace_back(new GroupedGemmOperationProfiler(options)); } CutlassProfiler::~CutlassProfiler() { @@ -201,6 +204,7 @@ void CutlassProfiler::print_usage_(std::ostream &out) { << " $ cutlass_profiler --operation=Conv3d --help\n\n" << " $ cutlass_profiler --operation=Conv2d --help\n\n" << " $ cutlass_profiler --operation=SparseGemm --help\n\n" + << " $ cutlass_profiler --operation=GroupedGemm --help\n\n" ; } diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 741a5f04..26efacdc 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -616,7 +616,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; - case library::NumericTypeID::kFUE4M3: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -657,7 +656,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; - case library::NumericTypeID::kF64: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -823,7 +821,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { ); break; - case library::NumericTypeID::kFE2M3: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -856,7 +853,6 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; - case library::NumericTypeID::kF16: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -1086,7 +1082,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) { ); break; - case library::NumericTypeID::kFE2M3: cutlass::reference::device::BlockFillSequential( reinterpret_cast(pointer_), @@ -1119,7 +1114,6 @@ void DeviceAllocation::initialize_sequential_device(Distribution dist) { static_cast(dist.sequential.start) ); break; - case library::NumericTypeID::kF16: cutlass::reference::device::BlockFillSequential( reinterpret_cast(pointer_), @@ -1360,7 +1354,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) { ); break; - case library::NumericTypeID::kFE2M3: cutlass::reference::host::BlockFillSequential( reinterpret_cast(host_data.data()), @@ -1393,7 +1386,6 @@ void DeviceAllocation::initialize_sequential_host(Distribution dist) { static_cast(dist.sequential.start) ); break; - case library::NumericTypeID::kF16: cutlass::reference::host::BlockFillSequential( reinterpret_cast(host_data.data()), @@ -1690,7 +1682,6 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); - case library::NumericTypeID::kFUE4M3: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -1717,7 +1708,6 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); - case library::NumericTypeID::kF16: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -1886,7 +1876,6 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); - case library::NumericTypeID::kFUE4M3: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -1925,7 +1914,6 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); - case library::NumericTypeID::kF16: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -2273,7 +2261,6 @@ void DeviceAllocation::write_tensor_csv( write_tensor_csv_static_type(out, *this); break; - case library::NumericTypeID::kFE2M3: write_tensor_csv_static_type(out, *this); break; @@ -2288,7 +2275,6 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kFUE8M0: write_tensor_csv_static_type(out, *this); break; - case library::NumericTypeID::kF16: write_tensor_csv_static_type(out, *this); break; @@ -2475,7 +2461,6 @@ void DeviceAllocation::fill_device(double val = 0.0) { case library::NumericTypeID::kFE2M1: tensor_fill(*this, static_cast(val)); break; - case library::NumericTypeID::kF16: tensor_fill(*this, static_cast(val)); @@ -2611,7 +2596,6 @@ void DeviceAllocation::fill_host(double val = 0.0) { static_cast(val) ); break; - case library::NumericTypeID::kFE4M3: cutlass::reference::host::BlockFill( diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index 0ac618fd..629770a3 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -75,6 +75,35 @@ DeviceAllocation *DeviceContext::allocate_tensor( return allocation; } +static void initialize_allocation_with_data_distribution( + Options const &options, + int seed_shift, + DeviceAllocation *allocation, + Distribution &data_distribution) { + if (options.initialization.provider == library::Provider::kReferenceDevice) { + if (data_distribution.kind == Distribution::Sequential) { + allocation->initialize_sequential_device( + data_distribution); + } + else { + allocation->initialize_random_device( + options.initialization.seed + seed_shift, + data_distribution); + } + } + else if (options.initialization.provider == library::Provider::kReferenceHost) { + if (data_distribution.kind == Distribution::Sequential) { + allocation->initialize_sequential_host( + data_distribution); + } + else { + allocation->initialize_random_host( + options.initialization.seed + seed_shift, + data_distribution); + } + } +} + /// Allocates memory of a given type, capacity (elements), and name DeviceAllocation *DeviceContext::allocate_and_initialize_tensor( Options const &options, @@ -122,7 +151,6 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor( data_distribution.set_uniform(1, 4, 0); break; - case library::NumericTypeID::kF16: data_distribution.set_uniform(-3, 3, 0); break; @@ -168,28 +196,9 @@ DeviceAllocation *DeviceContext::allocate_and_initialize_tensor( } } - if (options.initialization.provider == library::Provider::kReferenceDevice) { - if (data_distribution.kind == Distribution::Sequential) { - allocation->initialize_sequential_device( - data_distribution); - } - else { - allocation->initialize_random_device( - options.initialization.seed + seed_shift, - data_distribution); - } - } - else if (options.initialization.provider == library::Provider::kReferenceHost) { - if (data_distribution.kind == Distribution::Sequential) { - allocation->initialize_sequential_host( - data_distribution); - } - else { - allocation->initialize_random_host( - options.initialization.seed + seed_shift, - data_distribution); - } - } + initialize_allocation_with_data_distribution( + options, seed_shift, allocation, data_distribution + ); } return allocation; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index fc2346d2..39625750 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -79,6 +79,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options): {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_a", "runtime-input-datatype::a"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, {ArgumentTypeID::kEnumerated, {"runtime_input_datatype_b", "runtime-input-datatype::b"}, "Runtime datatype (e4m3, e5m2, e3m2, e2m3, e2m1)"}, {ArgumentTypeID::kInteger, {"use_pdl", "use-pdl"}, "Use PDL (true, false)"}, + {ArgumentTypeID::kEnumerated, {"enable_sm90_mixed_dtype_shuffle_test", "enable-sm90-mixed-dtype-shuffle-test"}, "Enable SM90 mixed input data type kernel shuffle layout test (true, false)"}, {ArgumentTypeID::kInteger, {"swizzle_size", "swizzle-size"}, "Size to swizzle"}, }, { library::Provider::kCUBLAS} @@ -211,6 +212,11 @@ Status GemmOperationProfiler::GemmProblem::parse( this->use_pdl = false; } + if (!arg_as_bool(this->enable_sm90_mixed_dtype_shuffle_test, "enable_sm90_mixed_dtype_shuffle_test", problem_space, problem)) { + // default value + this->enable_sm90_mixed_dtype_shuffle_test = false; + } + if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { // default value this->split_k_mode = library::SplitKMode::kSerial; @@ -399,6 +405,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "raster_order", problem_space, library::to_string(raster_order)); set_argument(result, "swizzle_size", problem_space, swizzle_size); set_argument(result, "use_pdl", problem_space, library::to_string(use_pdl)); + set_argument(result, "enable_sm90_mixed_dtype_shuffle_test", problem_space, library::to_string(enable_sm90_mixed_dtype_shuffle_test)); set_argument(result, "runtime_input_datatype_a", problem_space, library::to_string(runtime_input_datatype_a)); @@ -432,14 +439,26 @@ Status GemmOperationProfiler::initialize_configuration( Status status = problem_.parse(operation_desc, problem_space, problem); + // Note: this is a temporary workaround + bool is_current_operation_sm90_mixed_dtype_shuffle = (strstr(operation_desc.name, "_shfl") != NULL); + if (is_current_operation_sm90_mixed_dtype_shuffle && (problem_.enable_sm90_mixed_dtype_shuffle_test == false)) { + return Status::kErrorInvalidProblem; + } + if (status != Status::kSuccess) { return status; } - const auto device_count = options.device.devices.size(); + auto const device_count = options.device.devices.size(); gemm_workspace_.clear(); + library::NumericTypeID a_elem = library::get_real_type(operation_desc.A.element); + library::NumericTypeID b_elem = library::get_real_type(operation_desc.B.element); + int a_elem_bits = library::sizeof_bits(a_elem); + int b_elem_bits = library::sizeof_bits(b_elem); + bool is_mixed_input = (a_elem_bits != b_elem_bits); + for (size_t i = 0; i < device_count; ++i) { cudaSetDevice(options.device.device_id(i)); gemm_workspace_.emplace_back(); @@ -455,7 +474,6 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_[i].configuration.cluster_shape_fallback.m() = int(problem_.cluster_m_fallback); gemm_workspace_[i].configuration.cluster_shape_fallback.n() = int(problem_.cluster_n_fallback); gemm_workspace_[i].configuration.cluster_shape_fallback.k() = int(problem_.cluster_k_fallback); - gemm_workspace_[i].configuration.lda = problem_.lda; gemm_workspace_[i].configuration.ldb = problem_.ldb; gemm_workspace_[i].configuration.ldc = problem_.ldc; @@ -501,7 +519,77 @@ Status GemmOperationProfiler::initialize_configuration( initialize_result_(this->model_result_, options, operation_desc, problem_space); - if (const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); can_implement != Status::kSuccess) { + if (is_mixed_input) + { + const int options_g = problem_.k; + const int options_l = problem_.batch_count; + const int scale_k = (problem_.k + options_g - 1) / options_g; + // We cannot get the mainloop's ElementScale and ElementZero here, + // use the wide type to allocate a large enough workspace for S and Z. + library::NumericTypeID wide_dtype; + size_t SZ_mat_size = 0; + if (a_elem_bits > b_elem_bits) { + wide_dtype = a_elem; + SZ_mat_size = static_cast(problem_.n * scale_k); + } + else { + wide_dtype = b_elem; + SZ_mat_size = static_cast(problem_.m * scale_k); + } + + gemm_workspace_[i].Scale = device_context.allocate_tensor( + options, + "Scale", + wide_dtype, + library::LayoutTypeID::kRowMajor, + {int(SZ_mat_size), int(options_l)}, + {int(options_l)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + gemm_workspace_[i].Zero = device_context.allocate_tensor( + options, + "Zero", + wide_dtype, + library::LayoutTypeID::kRowMajor, + {int(SZ_mat_size), int(options_l)}, + {int(options_l)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + + // Packed scale is for int4 * fp8, where the original scale is fp8, and + // each scale element will be packed into an Array which is 64-bit + gemm_workspace_[i].packed_Scale = device_context.allocate_tensor( + options, + "packed-Scale", + library::NumericTypeID::kU64, + library::LayoutTypeID::kRowMajor, + {int(SZ_mat_size), int(options_l)}, + {int(options_l)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + + gemm_workspace_[i].arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; + gemm_workspace_[i].arguments.batch_count = problem_.batch_count; + + // Here is the first touch of the arguments, mark the mixed dtype, + // populate the scale and zero tensors in the following can_implement() call later. + // A and B are not populated at this moment, so do not update the dequantized A or B + gemm_workspace_[i].arguments.is_mixed_dtype = true; + gemm_workspace_[i].arguments.wider_operand = (a_elem_bits > b_elem_bits) ? cutlass::library::Sm90MixedInputWiderOperand::A : cutlass::library::Sm90MixedInputWiderOperand::B; + gemm_workspace_[i].arguments.generate_scale_and_zero = true; + gemm_workspace_[i].arguments.generate_dequantized_AB = false; + gemm_workspace_[i].arguments.dequantized_AB_ready = (bool *) malloc(sizeof(bool)); + gemm_workspace_[i].arguments.dequantized_AB_ready[0] = false; + gemm_workspace_[i].arguments.Scale = gemm_workspace_[i].Scale->data(); + gemm_workspace_[i].arguments.Zero = gemm_workspace_[i].Zero->data(); + gemm_workspace_[i].arguments.packed_Scale = gemm_workspace_[i].packed_Scale->data(); + } // End of "if (is_mixed_input)" + + const auto can_implement = operation->can_implement(&gemm_workspace_[i].configuration, &gemm_workspace_[i].arguments); + if (can_implement != Status::kSuccess) { return can_implement; } } @@ -693,6 +781,56 @@ Status GemmOperationProfiler::initialize_workspace( problem_.batch_count * gemm_workspace_[i].problem_count, i // device_index ); + + if (gemm_workspace_[i].arguments.is_mixed_dtype) { + // Dequantized tensor has the same shape of the narrow data type tensor, + // and the same data type as the wide data type tensor + // Encoded tensor has the same shape and data type of the narrow data type tensor + if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) { + gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor( + options, + "dequantized-B", + operation_desc.A.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + gemm_workspace_[i].encoded_AB = device_context.allocate_tensor( + options, + "encoded-B", + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + } + else { + gemm_workspace_[i].dequantized_AB = device_context.allocate_tensor( + options, + "dequantized-A", + operation_desc.B.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k)}, + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + gemm_workspace_[i].encoded_AB = device_context.allocate_tensor( + options, + "encoded-A", + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k)}, + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_[i].problem_count, + i // device_index + ); + } + } } if (options.execution_mode != ExecutionMode::kDryRun) { @@ -712,7 +850,7 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); /* Query device SM count to pass onto the kernel as an argument, where needed */ - gemm_workspace_[i].arguments.sm_count = options.device.properties[0].multiProcessorCount; + gemm_workspace_[i].arguments.sm_count = options.device.properties[i].multiProcessorCount; gemm_workspace_[i].arguments.device_index = static_cast(i); } } @@ -836,6 +974,17 @@ bool GemmOperationProfiler::verify_cutlass( gemm_workspace_[i].arguments.batch_stride_C = gemm_workspace_[i].C->batch_stride(); gemm_workspace_[i].arguments.batch_stride_D = gemm_workspace_[i].Computed->batch_stride(); + if (gemm_workspace_[i].arguments.is_mixed_dtype) { + // Scale and zero already generated in initialize_configuration(), + // A and B already generated in initialize_workspace(), signal + // GemmUniversal3xOperation::update_arguments_() (trigger by underlying_operation->run()) + // to generate the dequantized matrix for verification + gemm_workspace_[i].arguments.generate_scale_and_zero = false; + gemm_workspace_[i].arguments.generate_dequantized_AB = true; + gemm_workspace_[i].arguments.dequantized_AB = gemm_workspace_[i].dequantized_AB->data(); + gemm_workspace_[i].arguments.encoded_AB = gemm_workspace_[i].encoded_AB->data(); + } + if (problem_.split_k_mode == library::SplitKMode::kParallel) { gemm_workspace_[i].arguments.D = gemm_workspace_[i].device_workspace.data(); gemm_workspace_[i].arguments.alpha = problem_.alpha_one.data(); @@ -1133,7 +1282,6 @@ bool GemmOperationProfiler::verify_with_reference_( // // Initialize state // - for (auto provider : options.verification.providers) { // Skip providers that are not enabled @@ -1149,6 +1297,21 @@ bool GemmOperationProfiler::verify_with_reference_( void *ptr_C = gemm_workspace_[i].C->data(); void *ptr_D = gemm_workspace_[i].Reference->data(); + cutlass::library::NumericTypeID element_A_for_reference = element_A; + cutlass::library::NumericTypeID element_B_for_reference = element_B; + if (gemm_workspace_[i].arguments.is_mixed_dtype && gemm_workspace_[i].arguments.dequantized_AB_ready[0]) { + // Dequantized tensor has the same shape of the narrow data type tensor, + // and the same data type as the wide data type tensor + if (gemm_workspace_[i].arguments.wider_operand == cutlass::library::Sm90MixedInputWiderOperand::A) { + ptr_B = gemm_workspace_[i].dequantized_AB->data(); + element_B_for_reference = element_A; + } + else { + ptr_A = gemm_workspace_[i].dequantized_AB->data(); + element_A_for_reference = element_B; + } + } + // To support the host-side reference, conditionally allocate and // copy tensors to host memory. std::vector host_data_A; @@ -1200,13 +1363,13 @@ bool GemmOperationProfiler::verify_with_reference_( problem_.alpha.data(), - element_A, + element_A_for_reference, gemm_desc.A.layout, gemm_desc.transform_A, ptr_A, int(gemm_workspace_[i].configuration.lda), - element_B, + element_B_for_reference, gemm_desc.B.layout, gemm_desc.transform_B, ptr_B, @@ -1349,6 +1512,13 @@ Status GemmOperationProfiler::profile_cutlass_( gemm_workspace_[dev_id].arguments.C = gemm_workspace_[dev_id].C->batch_data(problem_idx); gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].Computed->batch_data(problem_idx); + if (gemm_workspace_[dev_id].arguments.is_mixed_dtype) { + // Scale, zero, and dequantized tensors are already generated in + // verify_cutlass(), no need to re-generate them in profiling + gemm_workspace_[dev_id].arguments.generate_scale_and_zero = false; + gemm_workspace_[dev_id].arguments.generate_dequantized_AB = false; + } + if (problem_.split_k_mode == library::SplitKMode::kParallel) { gemm_workspace_[dev_id].arguments.D = gemm_workspace_[dev_id].device_workspace.data(); @@ -1383,11 +1553,6 @@ Status GemmOperationProfiler::profile_cutlass_( return Status::kSuccess; }; - if (options.device.devices.size() == 1) { - auto func = [&](cudaStream_t stream, int iteration) { return launch_gemm(0, stream, iteration); }; - return profile_kernel_(result, options, func, gemm_workspace_[0].stream); - } - std::vector streams(gemm_workspace_.size()); for (size_t i = 0; i < streams.size(); i++) { streams[i] = gemm_workspace_[i].stream; diff --git a/tools/profiler/src/grouped_gemm_operation_profiler.cu b/tools/profiler/src/grouped_gemm_operation_profiler.cu new file mode 100644 index 00000000..56928a9e --- /dev/null +++ b/tools/profiler/src/grouped_gemm_operation_profiler.cu @@ -0,0 +1,1034 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/profiler/grouped_gemm_operation_profiler.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/library.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace { +std::vector> parseProblemSizes(std::string const& input) { + // input must be of the form: + // `[m0xn0xk0][m1xn1xk1]` where 0, 1 are the group indexes + std::stringstream ss(input); + std::string token; + std::vector> result; + while (std::getline(ss, token, ']')) { + std::stringstream ss(token); + std::string token; + ss.get(); // discard '[' + std::getline(ss, token, 'x'); + auto m = std::stoi(token); + std::getline(ss, token, 'x'); + auto n = std::stoi(token); + std::getline(ss, token); + auto k = std::stoi(token); + result.push_back({m, n, k}); + } + return result; +} +} // namespace + +namespace cutlass { +namespace profiler { + +GroupedGemmOperationProfiler::GroupedGemmOperationProfiler(Options const& options) + : OperationProfiler( + options, + library::OperationKind::kGroupedGemm, + {{ArgumentTypeID::kEnumerated, + {"gemm_kind"}, + "Variant of GEMM (universal, gemm, planar_complex, planar_complex_array)"}, + {ArgumentTypeID::kInteger, + {"m", "problem-size::m"}, + "M dimension of the GEMM problem space (for all groups)"}, + {ArgumentTypeID::kInteger, + {"n", "problem-size::n"}, + "N dimension of the GEMM problem space (for all groups)"}, + {ArgumentTypeID::kInteger, + {"k", "problem-size::k"}, + "K dimension of the GEMM problem space (for all groups)"}, + {ArgumentTypeID::kInteger, + {"num_groups"}, + "If m,n,k are specified, run a grouped GEMM with this number of groups, where each GEMM " + "uses the same m,n,k values."}, + {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, + {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, + {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, + {ArgumentTypeID::kTensor, {"D"}, "Tensor storing the D output"}, + {ArgumentTypeID::kScalar, + {"alpha", "epilogue::alpha"}, + "Epilogue scalar alpha (applied to all GEMMs in group)."}, + {ArgumentTypeID::kScalar, + {"beta", "epilogue::beta"}, + "Epilogue scalar beta (applied to all GEMMs in group)."}, + {ArgumentTypeID::kScalar, + {"problem-sizes"}, + "MxNxK Problem sizes for the grouped GEMM, where a group is enclosed by `[]`. E.g. " + "--problem-sizes='[m1xn1xk1][m2xn2xk2]'"}, + {ArgumentTypeID::kScalar, + {"problem-sizes-file"}, + "File containing grouped GEMM problem sizes, where each line represents a group whose " + "GEMM dimensions are 'mxnxk'."}}, + {library::Provider::kReferenceDevice}) { + + description_ = " Grouped matrix-matrix product. D[g] = alpha[g] * A[g] * B[g] + beta[g] * " + "C[g] for g in [0, num_groups)"; +} + +GroupedGemmOperationProfiler::~GroupedGemmOperationProfiler() {} + +void GroupedGemmOperationProfiler::print_usage(std::ostream& out) const { + OperationProfiler::print_usage(out); +} + +void GroupedGemmOperationProfiler::print_examples(std::ostream& out) const { + + out + << "\nExamples:\n\n" + << "Profile a particular problem size (explicit shapes):\n" + << " $ cutlass_profiler --operation=GroupedGemm --problem-sizes='[1024x1024x128][16x8x8]'\n\n" + + << "Profile a particular problem size (same M, N, K for all groups):\n" + << " $ cutlass_profiler --operation=GroupedGemm --m=16 --n=32 --k=64 --num_groups=8'\n\n" + + << "Profile a particular problem size from a file:\n" + << " $ cutlass_profiler --operation=GroupedGemm --problem-sizes-file=shapes.txt\n\n" + + << "Schmoo over problem size and beta:\n" + << " $ cutlass_profiler --operation=GroupedGemm --problem-sizes='[8x8x8],[16x8x16][32x32x32]' " + "--beta=0,1,2.5\n\n" + + << "Schmoo over accumulator types:\n" + << " $ cutlass_profiler --operation=GroupedGemm --accumulator-type=f16,f32\n\n" + + << "Run when A is f16 with column-major and B is any datatype with row-major (For column " + "major, use column, col, or n. For row major use, row or t):\n" + << " $ cutlass_profiler --operation=GroupedGemm --A=f16:column --B=*:row\n\n" + + << "Using various input value distribution:\n" + << " $ cutlass_profiler --operation=GroupedGemm --dist=uniform,min:0,max:3\n" + << " $ cutlass_profiler --operation=GroupedGemm --dist=gaussian,mean:0,stddev:3\n" + << " $ cutlass_profiler --operation=GroupedGemm --dist=sequential,start:0,delta:1\n\n" + + << "Test your changes to gemm kernels with a quick functional test and save results in " + "functional-test.csv:\n" + << " $ cutlass_profiler --operation=Gemm \\ \n" + << " --problem-sizes='[8x8x8][5x10x5],[16x8x16][32x32x32]' \\ \n" + << " --beta=0,1,2 --profiling-iterations=1 \\ \n" + << " --providers=cutlass --output=functional-test.csv\n\n"; +} + +Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + this->mode = library::GemmUniversalMode::kGrouped; + + std::bitset<3> args_exist; + std::string problem_sizes_str; + args_exist[0] = arg_as_string(problem_sizes_str, "problem-sizes", problem_space, problem); + int m, n, k; + args_exist[1] = arg_as_int(m, "m", problem_space, problem) && + arg_as_int(n, "n", problem_space, problem) && + arg_as_int(k, "k", problem_space, problem); + std::string problem_file; + args_exist[2] = arg_as_string(problem_file, "problem-sizes-file", problem_space, problem); + + if (args_exist.count() == 0) { + int num_groups = 8; + problem_sizes.resize(num_groups); + problem_sizes_3x.resize(num_groups); + int m0 = 16; + int n0 = 32; + int k0 = 64; + for (int i = 0; i < num_groups; i++) { + auto m = m0 * (i + 1); + auto n = n0 * (i + 1); + auto k = k0 * (i + 1); + problem_sizes[i] = {m, n, k}; + problem_sizes_3x[i] = {m, n, k}; + } + } + else if (args_exist.count() > 1) { + std::cerr + << "Exactly one of --problem-sizes, --problem-sizes-file, or --m --n --k may be specified.\n"; + return Status::kErrorInvalidProblem; + } + // --problem-sizes path + else if (args_exist[0]) { + auto problems = parseProblemSizes(problem_sizes_str); + auto num_groups = problems.size(); + problem_sizes.resize(num_groups); + problem_sizes_3x.resize(num_groups); + for (size_t i = 0; i < num_groups; i++) { + auto m = problems[i][0]; + auto n = problems[i][1]; + auto k = problems[i][2]; + problem_sizes[i] = {m, n, k}; + problem_sizes_3x[i] = {m, n, k}; + } + } + // m, n, k path + else if (args_exist[1]) { + int num_groups; + if (!arg_as_int(num_groups, "num_groups", problem_space, problem)) { + std::cerr << "num_groups must be specified if --m --n and --k are set.\n"; + return Status::kErrorInvalidProblem; + } + problem_sizes.resize(num_groups); + problem_sizes_3x.resize(num_groups); + for (int i = 0; i < num_groups; i++) { + problem_sizes[i] = {m, n, k}; + problem_sizes_3x[i] = {m, n, k}; + } + } + // --problem-sizes-file path + else if (args_exist[2]) { + std::ifstream file(problem_file); + if (!file.good()) { + throw std::runtime_error("Failed to open file: " + problem_file); + } + + for (std::string line; std::getline(file, line);) { + std::istringstream iss(line); + + int m, n, k; + char sep1, sep2; + std::string remaining; + + if (iss >> m >> sep1 >> n >> sep2 >> k && sep1 == 'x' && sep2 == 'x' && !(iss >> remaining)) { + problem_sizes.emplace_back(m, n, k); + problem_sizes_3x.emplace_back(m, n, k); + } else { + throw std::runtime_error( + "Invalid format in line: " + line + ". Each line in file expected to be 'mxnxk'."); + } + } + } + + if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { + // default value + this->cluster_m = 1; + } + + if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { + // default value + this->cluster_n = 1; + } + + if (!arg_as_int(this->cluster_k, "cluster_k", problem_space, problem)) { + // default value + this->cluster_k = 1; + } + + if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { + // default value + this->cluster_m_fallback = 0; + } + + if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { + // default value + this->cluster_n_fallback = 0; + } + + if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { + // default value + this->cluster_k_fallback = 0; + } + + this->mode = library::GemmUniversalMode::kGrouped; + + if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.D, "D", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!arg_as_scalar( + this->alpha, + operation_desc.element_epilogue, + "alpha", + problem_space, + problem)) { + + if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { + return Status::kErrorInternal; + } + } + + if (!arg_as_scalar(this->beta, operation_desc.element_epilogue, "beta", problem_space, problem)) { + + if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { + return Status::kErrorInternal; + } + } + + auto num_groups = problem_sizes.size(); + this->lda.resize(num_groups); + this->ldb.resize(num_groups); + this->ldc.resize(num_groups); + for (size_t group_idx = 0; group_idx < num_groups; group_idx++) { + this->lda[group_idx] = DeviceAllocation::get_packed_layout( + operation_desc.A.layout, + {int(this->m(group_idx)), int(this->k(group_idx))}) + .front(); + + this->ldb[group_idx] = DeviceAllocation::get_packed_layout( + operation_desc.B.layout, + {int(this->k(group_idx)), int(this->n(group_idx))}) + .front(); + + this->ldc[group_idx] = DeviceAllocation::get_packed_layout( + operation_desc.C.layout, + {int(this->m(group_idx)), int(this->n(group_idx))}) + .front(); + } + + return Status::kSuccess; +} + +/// Total number of bytes loaded +int64_t GroupedGemmOperationProfiler::GroupedGemmProblem::bytes( + library::GemmDescription const& operation_desc) const { + // Input bytes read and Output bytes written for the gemm problem + int64_t bytes = 0; + for (size_t group_idx = 0, num_groups = problem_sizes.size(); group_idx < num_groups; + group_idx++) { + + bytes += + int64_t(library::sizeof_bits(operation_desc.A.element) * m(group_idx) / 8) * k(group_idx) + + int64_t(library::sizeof_bits(operation_desc.B.element) * n(group_idx) / 8) * k(group_idx) + + int64_t(library::sizeof_bits(operation_desc.C.element) * m(group_idx) / 8) * n(group_idx); + + // Set is_beta_zero true if beta is zero + bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i == 0; }); + // Output bytes read for the gemm problem for non-zero beta values + if (!is_beta_zero) { + bytes += + int64_t(library::sizeof_bits(operation_desc.C.element) * m(group_idx) / 8) * n(group_idx); + } + } + + return bytes; +} + +/// Total number of flops computed +int64_t GroupedGemmOperationProfiler::GroupedGemmProblem::flops( + library::GemmDescription const& operation_desc) const { + int64_t flops_ = 0; + for (size_t group_idx = 0, num_groups = problem_sizes.size(); group_idx < num_groups; + group_idx++) { + flops_ += + (int64_t(m(group_idx)) * n(group_idx) * k(group_idx) + m(group_idx) * n(group_idx)) * 2; + } + + // complex-valued support + switch (operation_desc.tile_description.math_instruction.math_operation) { + case library::MathOperationID::kMultiplyAddComplex: + case library::MathOperationID::kMultiplyAddComplexFastF32: + flops_ *= 4; + break; + case library::MathOperationID::kMultiplyAddGaussianComplex: + flops_ *= 3; + break; + + default: + break; + } + + return flops_; +} + +/// Initializes a performance result +void GroupedGemmOperationProfiler::GroupedGemmProblem::initialize_result( + PerformanceResult& result, + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space) { + + result.arguments.resize(problem_space.rank()); + + set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); + + set_argument( + result, + "A", + problem_space, + std::string(library::to_string(operation_desc.A.element)) + ":" + + library::to_string(operation_desc.A.layout)); + + set_argument( + result, + "B", + problem_space, + std::string(library::to_string(operation_desc.B.element)) + ":" + + library::to_string(operation_desc.B.layout)); + + set_argument( + result, + "C", + problem_space, + std::string(library::to_string(operation_desc.C.element)) + ":" + + library::to_string(operation_desc.C.layout)); + + set_argument( + result, + "D", + problem_space, + std::string(library::to_string(operation_desc.D.element)) + ":" + + library::to_string(operation_desc.D.layout)); + + { + std::stringstream ss; + ss << "'"; + for (auto const& problem_size : problem_sizes) { + ss << "["; + auto m = problem_size[0]; + auto n = problem_size[1]; + auto k = problem_size[2]; + ss << m << "x" << n << "x" << k; + ss << "]"; + } + ss << "'"; + set_argument(result, "problem-sizes", problem_space, ss.str()); + } + + set_argument(result, "cluster_m", problem_space, cluster_m); + set_argument(result, "cluster_n", problem_space, cluster_n); + set_argument(result, "cluster_k", problem_space, cluster_k); + set_argument(result, "cluster_m_fallback", problem_space, cluster_m_fallback); + set_argument(result, "cluster_n_fallback", problem_space, cluster_n_fallback); + set_argument(result, "cluster_k_fallback", problem_space, cluster_k_fallback); + + set_argument( + result, + "alpha", + problem_space, + library::lexical_cast(alpha, operation_desc.element_epilogue)); + + set_argument( + result, + "beta", + problem_space, + library::lexical_cast(beta, operation_desc.element_epilogue)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Extracts the problem dimensions +Status GroupedGemmOperationProfiler::initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + library::GemmDescription const& operation_desc = + static_cast(operation->description()); + + if (operation_desc.gemm_kind != library::GemmKind::kGrouped) { + return Status::kErrorInvalidProblem; + } + + Status status = problem_.parse(operation_desc, problem_space, problem); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = problem_.problem_sizes.size(); + gemm_workspace_.configuration.problem_count = num_groups; + gemm_workspace_.configuration.lda = problem_.lda.data(); + gemm_workspace_.configuration.ldb = problem_.ldb.data(); + gemm_workspace_.configuration.ldc = problem_.ldc.data(); + + initialize_result_(this->model_result_, options, operation_desc, problem_space); + + return status; +} + +/// Initializes the performance result +void GroupedGemmOperationProfiler::initialize_result_( + PerformanceResult& result, + Options const& options, + library::GemmDescription const& operation_desc, + ProblemSpace const& problem_space) { + + result.provider = library::Provider::kCUTLASS; + result.disposition = Disposition::kNotRun; + result.status = Status::kSuccess; + result.operation_name = operation_desc.name; + + problem_.initialize_result(result, operation_desc, problem_space); + + OperationProfiler::initialize_result_(result, operation_desc, problem_space); + + result.bytes = problem_.bytes(operation_desc); + result.flops = problem_.flops(operation_desc); + result.runtime = 0; + +} + +/// Initializes workspace +Status GroupedGemmOperationProfiler::initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + if (options.device.devices.size() != 1) { + throw std::runtime_error("This operation profiler only supports a single " + "device."); + } + + cudaError_t result; + result = cudaSetDevice(options.device.device_id(0)); + if (result != cudaSuccess) { + throw std::runtime_error("cudaSetDevice() failed."); + } + + library::Operation const* underlying_operation = operation; + library::GemmDescription const& operation_desc = + static_cast(operation->description()); + + // Compute the number of copies of the problem to avoid L2 camping. + if (!options.profiling.workspace_count) { + int64_t bytes = problem_.bytes(operation_desc); + if (bytes < 3 * int64_t(options.device.properties[0].l2CacheSize)) { + gemm_workspace_.problem_count = + 1 + int((3 * int64_t(options.device.properties[0].l2CacheSize)) / bytes); + } + else { + gemm_workspace_.problem_count = 1; + } + } + else { + gemm_workspace_.problem_count = options.profiling.workspace_count; + } + + bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; + if (allocate_device_tensors) { + size_t num_groups = problem_.problem_sizes.size(); + // input data + gemm_workspace_.A_ptr_array_host.resize(num_groups); + gemm_workspace_.B_ptr_array_host.resize(num_groups); + gemm_workspace_.C_ptr_array_host.resize(num_groups); + gemm_workspace_.D_ptr_array_host.resize(num_groups); + static_assert(sizeof(void*) == 8); // allocating blocks for pointers, so verify pointer size + // ldx + gemm_workspace_.lda_array_device = + device_context + .allocate_block(options, "lda_array", library::NumericTypeID::kS64, num_groups, 0); + gemm_workspace_.ldb_array_device = + device_context + .allocate_block(options, "ldb_array", library::NumericTypeID::kS64, num_groups, 0); + gemm_workspace_.ldc_array_device = + device_context + .allocate_block(options, "ldc_array", library::NumericTypeID::kS64, num_groups, 0); + gemm_workspace_.lda_array_device->copy_from_host(problem_.lda.data()); + gemm_workspace_.ldb_array_device->copy_from_host(problem_.ldb.data()); + gemm_workspace_.ldc_array_device->copy_from_host(problem_.ldc.data()); + // problem sizes + gemm_workspace_.problem_sizes_array_device = device_context.allocate_block( + options, + "problem_sizes_array", + library::NumericTypeID::kU8, + num_groups * sizeof(gemm::GemmCoord), + 0); + gemm_workspace_.problem_sizes_array_device->copy_from_host(problem_.problem_sizes.data()); + + gemm_workspace_.problem_sizes_3x_array_device = device_context.allocate_block( + options, + "problem_sizes_array_3x", + library::NumericTypeID::kU8, + num_groups * sizeof(cute::Shape), + 0); + gemm_workspace_.problem_sizes_3x_array_device->copy_from_host(problem_.problem_sizes_3x.data()); + + // reference + gemm_workspace_.reference_ptr_array_host.resize(num_groups); + + int seed_shift = 0; + for (size_t group_idx = 0; group_idx < num_groups; group_idx++) { + auto group_str = std::to_string(group_idx); + gemm_workspace_.A_ptr_array_host[group_idx] = device_context.allocate_and_initialize_tensor( + options, + "A_" + group_str, + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m(group_idx)), int(problem_.k(group_idx))}, + {int(problem_.lda[group_idx])}, + gemm_workspace_.problem_count, + seed_shift++, + 0); + gemm_workspace_.B_ptr_array_host[group_idx] = device_context.allocate_and_initialize_tensor( + options, + "B_" + group_str, + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k(group_idx)), int(problem_.n(group_idx))}, + {int(problem_.ldb[group_idx])}, + gemm_workspace_.problem_count, + seed_shift++, + 0); + gemm_workspace_.C_ptr_array_host[group_idx] = device_context.allocate_and_initialize_tensor( + options, + "C_" + group_str, + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m(group_idx)), int(problem_.n(group_idx))}, + {int(problem_.ldc[group_idx])}, + gemm_workspace_.problem_count, + seed_shift++, + 0); + gemm_workspace_.D_ptr_array_host[group_idx] = device_context.allocate_tensor( + options, + "D_" + group_str, + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m(group_idx)), int(problem_.n(group_idx))}, + {int(problem_.ldc[group_idx])}, + gemm_workspace_.problem_count, + 0); + + gemm_workspace_.reference_ptr_array_host[group_idx] = device_context.allocate_tensor( + options, + "Reference_" + group_str, + operation_desc.D.element, + operation_desc.D.layout, + {int(problem_.m(group_idx)), int(problem_.n(group_idx))}, + {int(problem_.ldc[group_idx])}, + gemm_workspace_.problem_count, + 0); + } + + // takes the allocated tensors and initializes an array of pointers per problem in the workspace + auto create_dev_ptr_array_all_workspace = [&]( + std::vector& dev_ptr_arrays, + std::vector const& input, + std::string const& id) { + auto num_workspaces = gemm_workspace_.problem_count; + dev_ptr_arrays.resize(num_workspaces); + // note "problem_count" here refers to input/output count for L2 cycling + for (int i = 0; i < gemm_workspace_.problem_count; i++) { + std::string name = id + "_ptr_array_workspace" + std::to_string(i); + dev_ptr_arrays[i] = + device_context.allocate_block(options, name, library::NumericTypeID::kU64, num_groups, 0); + std::vector group_ptrs(num_groups); + for (size_t group_idx = 0; group_idx < num_groups; group_idx++) { + group_ptrs[group_idx] = input[group_idx]->batch_data(i); + } + dev_ptr_arrays[i]->copy_from_host(group_ptrs.data()); + } + }; + create_dev_ptr_array_all_workspace( + gemm_workspace_.A_ptr_array_device, + gemm_workspace_.A_ptr_array_host, + "A"); + create_dev_ptr_array_all_workspace( + gemm_workspace_.B_ptr_array_device, + gemm_workspace_.B_ptr_array_host, + "B"); + create_dev_ptr_array_all_workspace( + gemm_workspace_.C_ptr_array_device, + gemm_workspace_.C_ptr_array_host, + "C"); + create_dev_ptr_array_all_workspace( + gemm_workspace_.D_ptr_array_device, + gemm_workspace_.D_ptr_array_host, + "D"); + } + init_arguments(options); + + // + // Initialize the CUTLASS operation + // + Status status = Status::kSuccess; + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + if (options.execution_mode != ExecutionMode::kDryRun) { + uint64_t workspace_size = + underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); + gemm_workspace_.host_workspace.resize(workspace_size, 0); + + workspace_size = underlying_operation->get_device_workspace_size( + &gemm_workspace_.configuration, + &gemm_workspace_.arguments); + gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + + status = underlying_operation->initialize( + &gemm_workspace_.configuration, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + if (status != Status::kSuccess) { + return status; + } + + status = underlying_operation->can_implement( + &gemm_workspace_.configuration, + &gemm_workspace_.arguments); + if (status != Status::kSuccess) { + return status; + } + } + + // + // If CUTLASS is enabled, generate a result for it + // + results_.push_back(model_result_); + results_.back().provider = library::Provider::kCUTLASS; + results_.back().op_kind = library::OperationKind::kGroupedGemm; + results_.back().disposition = Disposition::kNotRun; + + for (auto provider : verification_providers_) { + results_.back().verification_map[provider] = Disposition::kNotRun; + } + } + return status; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Verifies CUTLASS against references +bool GroupedGemmOperationProfiler::verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + return true; + } + + if (options.execution_mode == ExecutionMode::kDryRun) { + return true; + } + + init_arguments(options); + + library::Operation const* underlying_operation = operation; + results_.back().status = underlying_operation->run( + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + throw "failed"; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + + // CUTLASS op ran the but not yet verified against any verification provider + results_.back().disposition = Disposition::kNotVerified; + + // + // Run verification providers + // + + if (options.verification.enabled) { + +#if CUTLASS_ENABLE_CUBLAS + if (options.verification.provider_enabled(library::Provider::kCUBLAS)) { + // set verification map for cublas to not supported + results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kNotSupported; + } +#endif // #if CUTLASS_ENABLE_CUBLAS + + library::GemmDescription const& gemm_desc = + static_cast(operation->description()); + + bool verification_status = verify_with_reference_( + options, + report, + device_context, + operation, + problem_space, + problem, + gemm_desc.A.element, + gemm_desc.B.element); + + // Update disposition to worst case verification outcome among all + // verification providers which are supported + bool is_any_verification_run_passed = false; + for (auto& m : results_.back().verification_map) { + if (m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + results_.back().disposition = m.second; + return true; + } + if (!is_any_verification_run_passed && m.second == Disposition::kPassed) { + is_any_verification_run_passed = true; + } + } + + if (is_any_verification_run_passed) { + results_.back().disposition = Disposition::kPassed; + } + } + + // if verification.required is set, then return success iff at least one ref-check was run + if (options.verification.required) { + bool did_any_verification_run = false; + for (auto provider : options.verification.providers) { + did_any_verification_run |= + (Disposition::kNotRun != results_.back().verification_map[provider]); + } + + if (not did_any_verification_run) { + results_.back().status = Status::kErrorNotSupported; + return false; + } + } + + // Return true means continue profiling + return true; +} + +/// Verifies CUTLASS against host and device references +bool GroupedGemmOperationProfiler::verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B) { + library::GemmDescription const& gemm_desc = + static_cast(operation->description()); + + for (auto provider : options.verification.providers) { + + // Skip providers that are not enabled + if (!options.verification.provider_enabled(provider)) { + continue; + } + + auto status = Status::kSuccess; + auto disposition = Disposition::kFailed; + for (size_t group_idx = 0, num_groups = problem_.problem_sizes.size(); group_idx < num_groups; + group_idx++) { + void* ptr_A = gemm_workspace_.A_ptr_array_host[group_idx]->data(); + void* ptr_B = gemm_workspace_.B_ptr_array_host[group_idx]->data(); + void* ptr_C = gemm_workspace_.C_ptr_array_host[group_idx]->data(); + void* ptr_D = gemm_workspace_.reference_ptr_array_host[group_idx]->data(); + + // To support the host-side reference, conditionally allocate and + // copy tensors to host memory. + std::vector host_data_A; + std::vector host_data_B; + std::vector host_data_C; + std::vector host_data_D; + + if (provider == library::Provider::kReferenceHost) { + host_data_A.resize(gemm_workspace_.A_ptr_array_host[group_idx]->bytes()); + ptr_A = host_data_A.data(); + gemm_workspace_.A_ptr_array_host[group_idx]->copy_to_host( + ptr_A); // this is copying all the data for L2 busting as well + + host_data_B.resize(gemm_workspace_.B_ptr_array_host[group_idx]->bytes()); + ptr_B = host_data_B.data(); + gemm_workspace_.B_ptr_array_host[group_idx]->copy_to_host(ptr_B); + + host_data_C.resize(gemm_workspace_.C_ptr_array_host[group_idx]->bytes()); + ptr_C = host_data_C.data(); + gemm_workspace_.C_ptr_array_host[group_idx]->copy_to_host(ptr_C); + + host_data_D.resize(gemm_workspace_.reference_ptr_array_host[group_idx]->bytes()); + ptr_D = host_data_D.data(); + } + + library::Handle handle; + handle.set_provider(provider); + + status = handle.gemm_universal( + library::GemmUniversalMode::kGemm, + problem_.m(group_idx), + problem_.n(group_idx), + problem_.k(group_idx), + problem_.cluster_m, + problem_.cluster_n, + problem_.cluster_k, + problem_.cluster_m_fallback, + problem_.cluster_n_fallback, + problem_.cluster_k_fallback, + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + problem_.alpha.data(), + element_A, + gemm_desc.A.layout, + gemm_desc.transform_A, + ptr_A, + int(problem_.lda[group_idx]), + element_B, + gemm_desc.B.layout, + gemm_desc.transform_B, + ptr_B, + int(problem_.ldb[group_idx]), + problem_.beta.data(), + gemm_desc.C.element, + gemm_desc.C.layout, + ptr_C, + int(problem_.ldc[group_idx]), + gemm_desc.D.element, + gemm_desc.D.layout, + ptr_D, + int(problem_.ldc[group_idx]), + 1, + gemm_workspace_.A_ptr_array_host[group_idx]->batch_stride(), + gemm_workspace_.B_ptr_array_host[group_idx]->batch_stride(), + gemm_workspace_.C_ptr_array_host[group_idx]->batch_stride(), + gemm_workspace_.reference_ptr_array_host[group_idx]->batch_stride()); + + if (status != Status::kSuccess) + break; + + if (provider == library::Provider::kReferenceHost) { + gemm_workspace_.reference_ptr_array_host[group_idx]->copy_from_host(ptr_D); + } + + disposition = compare_tensors( + options, + *gemm_workspace_.D_ptr_array_host[group_idx], + *gemm_workspace_.reference_ptr_array_host[group_idx], + gemm_workspace_.D_ptr_array_host[group_idx]->batch_stride()); + if (disposition != Disposition::kPassed) + break; + } + if (status != Status::kSuccess) { + results_.back().verification_map[provider] = Disposition::kNotRun; + continue; + } + results_.back().status = status; + results_.back().verification_map[provider] = disposition; + + // Save workspace if incorrect + if ( + options.verification.save_workspace == SaveWorkspace::kIncorrect && + results_.back().verification_map[provider] == Disposition::kIncorrect) { + + save_workspace(device_context, options, gemm_desc, library::Provider::kCUTLASS, provider); + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Measures performance results +bool GroupedGemmOperationProfiler::profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + results_.back().status = profile_cutlass_( + results_.back(), + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + } + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Method to profile a CUTLASS Operation +Status GroupedGemmOperationProfiler::profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) { + + // initialize gemm underlying operation to handle parallel reduction + library::Operation const* underlying_operation = operation; + + auto func = [&](cudaStream_t stream, int iteration) { + // Iterate over copies of the problem in memory + int workspace_idx = options.profiling.warmup_iterations + iteration; + int problem_idx = (workspace_idx % gemm_workspace_.problem_count); + + gemm_workspace_.arguments.ptr_A = gemm_workspace_.A_ptr_array_device[problem_idx]->data(); + gemm_workspace_.arguments.ptr_B = gemm_workspace_.B_ptr_array_device[problem_idx]->data(); + gemm_workspace_.arguments.ptr_C = gemm_workspace_.C_ptr_array_device[problem_idx]->data(); + gemm_workspace_.arguments.ptr_D = gemm_workspace_.D_ptr_array_device[problem_idx]->data(); + + return underlying_operation->run(arguments, host_workspace, device_workspace); + }; + return profile_kernel_(result, options, func); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index 5a518d7f..3f071e89 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -57,16 +57,6 @@ /////////////////////////////////////////////////////////////////////////////////////////////////// -#define CUDA_CHECK(call) \ - do { \ - cudaError_t err = call; \ - if (err != cudaSuccess) { \ - std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " code=" << err << " \"" \ - << cudaGetErrorString(err) << "\"\n"; \ - return Status::kErrorInternal; \ - } \ - } while (0) - namespace cutlass { namespace profiler { /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -304,42 +294,43 @@ std::ostream& operator<<(std::ostream& out, library::Provider provider) { return out; } -std::ostream& operator<<(std::ostream& out, library::OperationKind provider) { - if (provider == library::OperationKind::kGemm) { +std::ostream& operator<<(std::ostream& out, library::OperationKind op_kind) { + if (op_kind == library::OperationKind::kGemm) { out << "kGemm"; } - - else if (provider == library::OperationKind::kBlockScaledGemm) { + else if (op_kind == library::OperationKind::kBlockScaledGemm) { out << "kBlockScaledGemm"; } - - else if (provider == library::OperationKind::kRankK) { + else if (op_kind == library::OperationKind::kRankK) { out << "kRankK"; } - else if (provider == library::OperationKind::kRank2K) { + else if (op_kind == library::OperationKind::kRank2K) { out << "kRank2K"; } - else if (provider == library::OperationKind::kTrmm) { + else if (op_kind == library::OperationKind::kTrmm) { out << "kTrmm"; } - else if (provider == library::OperationKind::kSymm) { + else if (op_kind == library::OperationKind::kSymm) { out << "kSymm"; } - else if (provider == library::OperationKind::kConv2d) { + else if (op_kind == library::OperationKind::kConv2d) { out << "kConv2d"; } - else if (provider == library::OperationKind::kConv3d) { + else if (op_kind == library::OperationKind::kConv3d) { out << "kConv3d"; } - else if (provider == library::OperationKind::kEqGemm) { + else if (op_kind == library::OperationKind::kEqGemm) { out << "kEqGemm"; } - else if (provider == library::OperationKind::kSparseGemm) { + else if (op_kind == library::OperationKind::kSparseGemm) { out << "kSparseGemm"; } - else if (provider == library::OperationKind::kReduction) { + else if (op_kind == library::OperationKind::kReduction) { out << "kReduction"; } + else if (op_kind == library::OperationKind::kGroupedGemm) { + out << "kGroupedGemm"; + } else { out << "kInvalid"; } @@ -660,6 +651,11 @@ void OperationProfiler::save_workspace( DeviceAllocation *allocation = named_allocation.second; + if (allocation->layout() == library::LayoutTypeID::kUnknown) { + continue; // write_tensor not set up to handle DeviceAllocations initialized using + // allocate_block() + } + std::stringstream filename; filename << desc.name << "_" << library::to_string(provider) << "_"; @@ -736,15 +732,20 @@ Status predict_iters( /// CUDA graphs allows you to record the launch of large numbers of kernels without /// blocking and therefore avoids a deadlock which happens if you try to enqueue too /// many kernels behind the spinloop kernel. -Status OperationProfiler::profile_kernel_( - PerformanceResult &result, - Options const &options, - const std::function &func, - const std::vector &streams) { +Status OperationProfiler::profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams) { + auto dev_count = streams.size(); + cuda::atomic *release; - CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable)); - release->store(false, cuda::memory_order_release); + + if (dev_count > 1) { + CUDA_CHECK(cudaHostAlloc(&release, sizeof(*release), cudaHostAllocPortable)); + release->store(false, cuda::memory_order_release); + } std::vector timer; for (size_t i = 0; i < dev_count; ++i) { @@ -774,9 +775,11 @@ Status OperationProfiler::profile_kernel_( for (size_t i = 0; i < dev_count; ++i) { CUDA_CHECK(cudaSetDevice(options.device.device_id(i))); CUDA_CHECK(cudaStreamBeginCapture(streams[i], cudaStreamCaptureModeGlobal)); - // Halt execution until all GPUs are ready to precede. - // It allows the CPU to trigger the GPUs all start at the same time. - delay<<<1, 1, 0, streams[i]>>>(release); + if (dev_count > 1) { + // Halt execution until all GPUs are ready to precede. + // It allows the CPU to trigger the GPUs all start at the same time. + delay<<<1, 1, 0, streams[i]>>>(release); + } for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { Status status = func(i, streams[i], iteration); if (status != Status::kSuccess) { @@ -803,8 +806,10 @@ Status OperationProfiler::profile_kernel_( CUDA_CHECK(cudaGraphLaunch(graphExecs[i], streams[i])); } - // release the enqueued kernels - release->store(true, cuda::memory_order_release); + if (dev_count > 1) { + // release the enqueued kernels + release->store(true, cuda::memory_order_release); + } for (size_t i = 0; i < dev_count; ++i) { CUDA_CHECK(cudaSetDevice(options.device.device_id(i))); @@ -819,7 +824,9 @@ Status OperationProfiler::profile_kernel_( } result.runtime /= static_cast(dev_count); - CUDA_CHECK(cudaFreeHost(release)); + if (dev_count > 1) { + CUDA_CHECK(cudaFreeHost(release)); + } for (size_t i = 0; i < dev_count; ++i) { CUDA_CHECK(cudaSetDevice(options.device.device_id(i))); @@ -835,11 +842,47 @@ Status OperationProfiler::profile_kernel_( return Status::kSuccess; } -/// Method to profile GPU execution time of a kernel launched in func Status OperationProfiler::profile_kernel_( PerformanceResult &result, Options const &options, - const std::function &func, + const std::function &func, + const std::vector &streams) { + + if (options.profiling.use_cuda_graphs) { + return profile_kernel_w_cuda_graphs_(result, options, func, streams); + } + else if (streams.size() == 1) { + auto single_device_func = [&](cudaStream_t stream, int iteration) { + return func(0, stream, iteration); + }; + return profile_kernel_no_cuda_graphs_(result, options, single_device_func, streams[0]); + } + return Status::kErrorNotSupported; +} + +/// Method to profile GPU execution time of a kernel launched in func +Status OperationProfiler::profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream) { + + if (options.profiling.use_cuda_graphs) { + auto graph_func = [&](int dev_id, cudaStream_t stream, int iteration) { + return func(stream, iteration); + }; + return profile_kernel_w_cuda_graphs_(result, options, graph_func, {stream}); + } else { + return profile_kernel_no_cuda_graphs_(result, options, func, stream); + } + return Status::kSuccess; +} + +/// Method to profile GPU execution time of a kernel launched in func +Status OperationProfiler::profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, cudaStream_t stream) { GpuTimer timer; diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 0adc2340..25e19493 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -477,6 +477,7 @@ Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { cmdline.get_cmd_line_argument("profiling-enabled", enabled, true); cmdline.get_cmd_line_argument("profiling-duration", duration, 10); cmdline.get_cmd_line_argument("min-iterations", min_iterations, 10); + cmdline.get_cmd_line_argument("use-cuda-graphs", use_cuda_graphs, false); if (cmdline.check_cmd_line_flag("providers")) { diff --git a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index 0d8ade05..60d6b51f 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -1203,6 +1203,34 @@ bool arg_as_scalar( return arg_as_scalar(bytes, numeric_type, value_ptr); } +/// Returns a copy of the string passed to the argument. +/// (kScalar arguments are stored as strings). +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem) { + + size_t idx = problem_space.argument_index(name); + KernelArgument::Value const* value_ptr = problem.at(idx).get(); + + if (value_ptr->not_null) { + if (value_ptr->argument->description->type == ArgumentTypeID::kScalar) { + std::string const& str_value = + static_cast(value_ptr)->value; + arg = std::string(str_value); + } + else { + throw std::runtime_error( + "arg_as_string() - illegal cast. Problem space argument must be scalar"); + } + + return true; + } + + return false; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns true if a tensor description satisfies a `tensor` value diff --git a/tools/util/include/cutlass/util/device_memory.h b/tools/util/include/cutlass/util/device_memory.h index b79b3f92..44f6a467 100644 --- a/tools/util/include/cutlass/util/device_memory.h +++ b/tools/util/include/cutlass/util/device_memory.h @@ -56,9 +56,7 @@ template T* allocate(size_t count = 1) { T* ptr = 0; - size_t bytes = 0; - - bytes = count * sizeof(T); + size_t bytes = count * sizeof_bits::value / 8; cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); diff --git a/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp new file mode 100644 index 00000000..68f824eb --- /dev/null +++ b/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -0,0 +1,480 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for mixed input data type kernels. +*/ + +#pragma once + +#include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cute/util/type_traits.hpp" + +namespace cutlass { + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleBroadCastLayout, + class ThrLayout> +__global__ void dequantize_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. + // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + auto init_quantized_iterator = [&]() { + if constexpr (cute::sizeof_bits_v >= 8) { + return cute::make_gmem_ptr(q_buffer); + } + else { + return cute::subbyte_iterator(q_buffer); + } + }; + cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout> +static void dequantize(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size, + cudaStream_t &stream) { + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. + storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { + + using StorageType = cutlass::int4b_t::Storage; + constexpr int pack = cute::sizeof_bits_v / 4; + const size_t host_buf_size = block_size / pack; + std::vector host_buf(host_buf_size); + cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); + + for (auto&& d : host_buf) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; i++) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); + return true; +} + +template +static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { + std::vector data_in(block_size); + std::vector> data_out(block_size); + + try { + cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + + for (size_t i = 0; i < block_size; i++) { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + + try { + cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +struct UnderlyingElement { + using type = T; +}; + +template +struct UnderlyingElement> { + using type = typename T::Element; +}; + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). +template , + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). + auto has_major_mode = [](auto s) { + return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, sizeof_bits{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} + +#undef CUDA_CHECK + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 534f5462..45be9e72 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -388,7 +388,7 @@ void compute_1d_scaling_factor_and_quantized_output( absolute_value_op abs_op; maximum_with_nan_propogation max_op; - if constexpr (cute::is_constant<1, decltype(cute::stride<0,1>(tensor_SfD))>::value) { + if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { // MN major output int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); // Col major output @@ -705,7 +705,7 @@ void gett_epilogue( if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { // Convert every type to ElementCompute first, do compute, convert to output type, write it out ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - // per-row alpha + // vector alpha if (raw_pointer_cast(epilogue_params.Valpha.data())) { converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); @@ -719,7 +719,7 @@ void gett_epilogue( if (raw_pointer_cast(epilogue_params.C.data())) { ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); - // per-row beta + // vector beta if (epilogue_params.Vbeta.data()) { converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); converted_beta = mul(converted_beta, converted_scale_c);