diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index a7733fab..4d910522 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -55,7 +55,7 @@ jobs: if: | (startsWith(github.event.comment.body, '/bot run') || startsWith(github.event.comment.body, '/bot kill')) && contains( - fromJson('["zekunf-nv"]'), + fromJson('["nv-fastkernels-cicd", "zekunf-nv", "hwu36", "IonThruster", "thakkarV", "d-k-b", "mihir-awatramani", "fengxie", "vickiw973", "Junkai-Wu", "brandon-yujie-sun", "lijingticy22", "hongw-nv", "vikgupta-nv", "IwakuraRein", "depaulmillz", "jackkosaian", "itramble", "ccecka", "sxtyzhangzk", "hbarclay", "yzhaiustc", "x86vk", "sklevtsov-nvidia", "ANIKET-SHIVAM", "Shreya-gaur", "azhurkevich", "serifyesil", "richardmcai", "lsyyy666", "Ethan-Yan27", "XiaoSong9905", "shdetect", "keithzzzzz"]'), github.actor) steps: - name: Check if comment is issued by authorized person diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c139971..7a897f95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,55 @@ # CUTLASS 4.x +## [4.2.0](https://github.com/NVIDIA/cutlass/tree/main) (2025-08-21) + +### CuTe DSL +* We will likely be skipping 4.2.dev release and directly target 4.2. +* CuTeDSL version remains at 4.1.0 till then. + +### CUTLASS C++ +* Add K major scale factor support for Hopper SM90 blockwise kernels. +* Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). + - Add fused reduction kernel support for cutlass MLA. + - Fix an issue where `get_unmasked_trip_count` may return a negative value. + - Fix an issue where mbarriers are initialized with a zero arrival count. +* Add Blackwell SM120 blockwise gemm kernel example: [example 87](https://github.com/NVIDIA/cutlass/tree/main/87_blackwell_geforce_gemm_blockwise/). +* Support for Blackwell SM100 cpasync kernel. + - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). + - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). +* Support for Blackwell SM121 kernels for DGX Spark GPUs. + - Share the major codes with Blackwell SM120 kernels. +* Support for Blackwell SM100 legacy mixed input GEMM kernels. + - Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp). + - Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp). + - Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/). +* Support for Blackwell SM100 fp4 gemv kernels. + - Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h). + - Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/) +* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110. + - For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs. + - For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid. +* CuTe changes: + - Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/). + - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support. + - Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels. + - Support fp16 accmulator for sm89 fp8 mma. + - Shorten `nullspace` implementation. + - Isolate and comment on `cosize` hacks. + - Important documentation correction: `E<0,1> == 1@0@1`. +* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics`. + - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). +* Rename legacy Python API package from `cutlass` to `cutlass_cppgen`. +* Fix some profiler issues: + - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. + - Fix some no output and timeout issues. +* Add following unit tests: + - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) + - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) + - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu) +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 13.0. + ## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16) ### CuTe DSL @@ -10,7 +59,7 @@ - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py) * API updates - - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details ### CUTLASS C++ * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). @@ -58,7 +107,7 @@ - [C-structure based customized interface between JIT function and user codes](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/ffi/jit_argument.py) * [Educational notebooks for getting started with CuTe DSL](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/notebooks) * API updates - - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details + - Please refer to [DSL API changelog](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_api/changelog.html) for details ### CUTLASS C++ * Support [Family Specific Architecture Features](https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/) which was introduced in CUDA 12.9 diff --git a/CMakeLists.txt b/CMakeLists.txt index 4088b71f..29fb4e21 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,13 +175,25 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.8) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100 100a 120 120a 121 121a) + if (CUDA_VERSION VERSION_LESS 13.0) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101 101a) + else() + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a) + endif() endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.9) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 100f 120f 121f 103a 103f) + if (CUDA_VERSION VERSION_LESS 13.0) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 101f) + else() + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110f) + endif() +endif() + +if (CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 110 110a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") @@ -288,18 +300,50 @@ if (KERNEL_FILTER_FILE) set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE) endif() +if (CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) + get_filename_component(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" ABSOLUTE) + set(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" CACHE STRING "HEURISTICS FILE FULL PATH" FORCE) +endif() + set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list") if(KERNEL_FILTER_FILE) message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}") endif() +if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) + message(STATUS "Full path of heuristics problems file: ${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}") + if(DEFINED CUTLASS_NVMMH_URL) + message(STATUS "CUTLASS_NVVMH_URL is set. Fetching dependency") + include(FetchContent) + FetchContent_Declare( + nvmmh + URL ${CUTLASS_NVMMH_URL} + ) + FetchContent_MakeAvailable(nvmmh) + FetchContent_GetProperties(nvmmh SOURCE_DIR nvmmh_dir) + set(CUTLASS_NVMMH_PATH "${nvmmh_dir}") + endif() + + if(DEFINED CUTLASS_NVMMH_PATH) + message(STATUS "CUTLASS_NVMMH_PATH is set. Using package at: ${CUTLASS_NVMMH_PATH}") + + set(CUTLASS_NVMMH_PY_DIR "${CUTLASS_NVMMH_PATH}/python/") + set(ENV{CUTLASS_NVMMH_SO_PATH} "${CUTLASS_NVMMH_PATH}/lib/libnvMatmulHeuristics.so") + endif() +endif() + set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.") set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") +if(CUTLASS_LIBRARY_INSTANTIATION_LEVEL OR CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) + message(STATUS "Enable extended SM90 WGMMA instruction shapes for instantiation levels") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +endif() + ################################################################################ set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests") @@ -350,6 +394,10 @@ if (CUTLASS_NVCC_ARCHS MATCHES 100f OR CUTLASS_NVCC_ARCHS MATCHES 101f) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED) endif() +if (CUTLASS_NVCC_ARCHS MATCHES 110f) +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SM100_FAMILY_ARCHS_ENABLED) +endif() + set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace") # @@ -428,8 +476,6 @@ endif() # ################################################################################################### - - # Warnings-as-error exceptions and warning suppressions for Clang builds if (CUTLASS_CLANG_HOST_COMPILE) @@ -705,6 +751,7 @@ target_include_directories( SYSTEM INTERFACE $ ) + if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0) target_include_directories( CUTLASS diff --git a/FUNCTIONALITY.md b/FUNCTIONALITY.md deleted file mode 100644 index a038b0cb..00000000 --- a/FUNCTIONALITY.md +++ /dev/null @@ -1,30 +0,0 @@ -# Changelog for CuTe DSL API changes - -## [4.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.1.0) (2025-07-16) - - * for loop - - Python built-in ``range`` now always generates IR and executes at runtime - - ``cutlass.range`` is advanced ``range`` with IR level unrolling and pipelining control - - Deprecated ``cutlass.range_dynamic``, please replace with ``range`` or ``cutlass.range`` - - **Experimental** Added ``pipelining`` control for compiler generated software pipeline code - * while/if - - ``while``/``if`` now by default generates IR and executes at runtime unless ``cutlass.const_expr`` is specified for the predicate - - Deprecated ``cutlass.dynamic_expr``, please remove it - * Rename mbarrier functions to reduce ambiguity - * Modify SyncObject API (`MbarrierArray`, `NamedBarrier`, `TmaStoreFence`) to match `std::barrier` - * Change pipeline `create` function to take only keyword arguments, and make `barrier_storage` optional. - * Introduce `cutlass.cute.arch.get_dyn_smem_size` api to get runtime dynamic shared memory size. - * Various API Support for SM100 BlockScaled Gemm - - Introduce BlockScaled MmaOps in [tcgen05/mma.py]([https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py]), and provide a `make_blockscaled_trivial_tiled_mma` function in [blackwell_helpers.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blackwell_helpers.py) to help construct a BlockScaled TiledMma. - - Introduce S2T CopyOps in [tcgen05/copy.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py). - - Introduce BlockScaled layout utilities in [blockscaled_layout.py](https://github.com/NVIDIA/cutlass/blob/main/python/CuTeDSL/cutlass/utils/blockscaled_layout.py) for creating the required scale factor layouts in global memory, shared memory and tensor memory. - * `cutlass.cute.compile` now supports compilation options. Refer to [JIT compilation options](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_jit_compilation_options.html) for more details. - * `cutlass.cute.testing.assert_` now works for device JIT function. Specify `--enable-device-assertions` as compilation option to enable. - * `cutlass.cute.make_tiled_copy` is now deprecated. Please use `cutlass.cute.make_tiled_copy_tv` instead. - * Shared memory capacity query - - Introduce `cutlass.utils.get_smem_capacity_in_bytes` for querying the shared memory capacity. - - `_utils.SMEM_CAPACITY[""]` is now deprecated. - -## [4.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.0.0) (2025-06-03) - - * Fixed API mismatch in class ``cute.runtime.Pointer``: change ``element_type`` to ``dtype`` to match ``typing.Pointer`` diff --git a/README.md b/README.md index 6b0edce8..06db5899 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.1.0 +# CUTLASS 4.2.0 -_CUTLASS 4.1.0 - July 2025_ +_CUTLASS 4.2.0 - Aug 2025_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -43,40 +43,52 @@ To get started quickly - please refer : - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/cpp/quickstart.html). - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/quick_start.html). -# What's New in CUTLASS 4.1 +# What's New in CUTLASS 4.2 ## CuTe DSL -* Add aarch64 support, you can now pip install `nvidia-cutlass-dsl` on GB200 systems! -* More examples demonstrating how to use CuTe DSL to write peak-performance kernels - - [Blackwell Mamba2 SSD](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py) - - [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py) -* API updates - - Please refer to [FUNCTIONALITY.md](https://github.com/NVIDIA/cutlass/blob/main/FUNCTIONALITY.md) for details +* We will likely be skipping 4.2.dev release and directly target 4.2. +* CuTeDSL version remains at 4.1.0 till then. ## CUTLASS C++ +* Add K major scale factor support for Hopper SM90 blockwise kernels. * Further enhance Blackwell SM100 Attention kernels in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). - - Add variable sequence length support for FMHA Backward kernel. - - Add varlen test support to Backward runner. - - Codes support empty batch sequences. -* Replace `subbyte_iterator` with `cute::recast_ptr` when constructing logical iterators/arrays. + - Add fused reduction kernel support for cutlass MLA. + - Fix an issue where `get_unmasked_trip_count` may return a negative value. + - Fix an issue where mbarriers are initialized with a zero arrival count. +* Add Blackwell SM120 blockwise gemm kernel example: [example 87](https://github.com/NVIDIA/cutlass/tree/main/87_blackwell_geforce_gemm_blockwise/). +* Support for Blackwell SM100 cpasync kernel. + - Collective mainloop codes: [cpasync mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp). + - Kernel codes: [cpasync kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp). +* Support for Blackwell SM121 kernels for DGX Spark GPUs. + - Share the major codes with Blackwell SM120 kernels. +* Support for Blackwell SM100 legacy mixed input GEMM kernels. + - Collective mainloop codes: [Mixed input mainloop](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp). + - Kernel codes: [Mixed input kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp). + - Example codes: [example 86](https://github.com/NVIDIA/cutlass/tree/main/examples/86_blackwell_mixed_dtype_gemm/). +* Support for Blackwell SM100 fp4 gemv kernels. + - Kernel codes: [Gemv kernel](https://github.com/NVIDIA/cutlass/tree/main/include/cutlass/gemm/kernel/gemv_blockscaled.h). + - Example codes: [example 91](https://github.com/NVIDIA/cutlass/tree/main/examples/91_fp4_gemv/) +* From CUDA 13.0, the Blackwell SM101 for Thor GPUs is renamed to SM110. + - For CUDA toolkit version < 13.0, SM101 is still used for Thor GPUs. + - For CUDA toolkit version >= 13.0, SM110 is used for Thor GPUs and SM101 is no longer valid. * CuTe changes: - - Rewrite ArithTuple and ScaledBasis for robustness and clarity. - - Remove buggy and kludgy `get_layoutA|B|C_MN` and friends from Atoms/TiledX. - - Factor out `print_latex` and friends and rewrite. - - Factor out `print_svg` and friends and rewrite. -* Support Blackwell SM100 SIMT packed fp32x2 kernels. -* Support residual add for implicit gemm kernels. -* Various fixes for CUTLASS C++ Python interface's EVT tracer: - - Add verifier for sm90 to report the invalid input. - - When adding an edge to the graph, if the edge already exists, add an identity compute node to avoid having multiple parallel edges. - - Register operations of tanh, sigmoid, exp, gelu to the python ast frontend. - - Replace the NotImplemented Error by packing all nodes into a single topological visitor node as a fallback. -* Fix profiler bugs in exhaustive perf search. - - Fix incorrect cluster shape output issue when doing exhaustive search. - - Fix a bug in profiler grouped GEMM for setting tile scheduler swizzles, cluster shapes, and raster orders. -* Fix some profiler issues. - - Complete the reference for Blackwell blockwise gemm kernels. - - Fix incorrect regex logic for L1 test. + - Fix inaccurate GridDim calculation under [CuTe tutorial](https://github.com/NVIDIA/cutlass/tree/main/examples/cute/tutorial/blackwell/). + - Add [movmatrix](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-movmatrix) support. + - Fix smallest MMA-N allowed for Blackwell fp8 and fp16 gemm kernels. + - Support fp16 accmulator for sm89 fp8 mma. + - Shorten `nullspace` implementation. + - Isolate and comment on `cosize` hacks. + - Important documentation correction: `E<0,1> == 1@0@1`. +* Add support for heuristics-based kernel filtering and autotuning using `nvidia-matmul-heuristics`. + - Details please refer to [heuristics doc](https://github.com/NVIDIA/cutlass/tree/main/media/docs/cpp/heuristics.md). +* Rename legacy Python API package from `cutlass` to `cutlass_cppgen`. +* Fix some profiler issues: + - Modify default cluster callback values to none 0 to avoid profiler failure when these values are not set in command line. + - Fix some no output and timeout issues. +* Add following unit tests: + - [fp16 accmulator for sm89 fp8 mma](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/ampere/cooperative_gemm.cu) + - [movmatrix test](https://github.com/NVIDIA/cutlass/tree/main/test/unit/cute/turing/movm.cu) + - [fp8 narrow mma n](https://github.com/NVIDIA/cutlass/tree/main/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu) and [fp16 narrow mma n](test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu) Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. diff --git a/customConfigs.cmake b/customConfigs.cmake index a7342044..ac86cbe1 100644 --- a/customConfigs.cmake +++ b/customConfigs.cmake @@ -65,7 +65,12 @@ endfunction() if(CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS) - set(PROFILER_ARCH_LIST 100a 100f 101a 101f 120a 120f) + set(PROFILER_ARCH_LIST 100a 100f 103a 120a 120f 121a) + if (CUDA_VERSION VERSION_LESS 13.0) + list(APPEND PROFILER_ARCH_LIST 101a 101f) + else() + list(APPEND PROFILER_ARCH_LIST 110a 110f) + endif() foreach(ARCH IN LISTS CUTLASS_NVCC_ARCHS) if(NOT (ARCH IN_LIST PROFILER_ARCH_LIST)) message(FATAL_ERROR "Only SM${PROFILER_ARCH_LIST} compute capabilities are supported with profiler-based unit tests") diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 0e221452..c07b1ea9 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -659,7 +659,7 @@ struct Testbed { } int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; - int64_t bytes = cutlass::bits_to_bytes( + int64_t bytes = cutlass::bits_to_bytes( (cutlass::sizeof_bits::value * 2 + cutlass::sizeof_bits::value) * options.problem_size.m() * options.problem_size.n()); diff --git a/examples/39_gemm_permute/layouts.h b/examples/39_gemm_permute/layouts.h index 5ffb04fd..d4d9ed31 100644 --- a/examples/39_gemm_permute/layouts.h +++ b/examples/39_gemm_permute/layouts.h @@ -33,8 +33,8 @@ computing reference permutations of 4/5D tensors when source data is column-major. */ #pragma once -#include #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" #include "cutlass/coord.h" diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h index 9ed17f4b..84f89a5f 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -40,14 +40,12 @@ Note that in general the fragment passed to the OutputOp could span multiple rows but it does not happen with the configurations we have */ - #pragma once -#include - #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/functional.h" #include "cutlass/layout/tensor.h" #include "cutlass/layout/vector.h" diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h index 973ec345..411b5574 100644 --- a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -42,12 +42,10 @@ */ #pragma once - -#include - +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/aligned_buffer.h" #include "cutlass/array.h" -#include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/layout/tensor.h" #include "cutlass/layout/vector.h" diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h index fdb96db9..04d988cc 100644 --- a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -38,10 +38,8 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/examples/45_dual_gemm/threadblock/dual_epilogue.h b/examples/45_dual_gemm/threadblock/dual_epilogue.h index a234b200..bcb34d89 100644 --- a/examples/45_dual_gemm/threadblock/dual_epilogue.h +++ b/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -37,12 +37,10 @@ */ #pragma once - -#include - +#include "cutlass/array.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" -#include "cutlass/array.h" #include "cutlass/layout/vector.h" #include "cutlass/layout/tensor.h" #include "cutlass/tensor_coord.h" diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index e7e3e4ea..ab88f54d 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -132,7 +132,7 @@ constexpr int ScaleGranularityK = 128; constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; -using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu index 1977b698..b187d2da 100644 --- a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling_with_sparse_groups.cu @@ -142,7 +142,7 @@ static constexpr int ScaleGranularityK = 128; static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; -using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; +using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand diff --git a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu index 19d6b89d..f9d14215 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu @@ -454,11 +454,12 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; return 0; - } - + } + // // Parse options // diff --git a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu index d476ce00..f0b85865 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu @@ -640,11 +640,11 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; return 0; - } - + } // // Parse options diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt index 0ac1687d..5bb294b1 100644 --- a/examples/70_blackwell_gemm/CMakeLists.txt +++ b/examples/70_blackwell_gemm/CMakeLists.txt @@ -33,7 +33,7 @@ set(TEST_SWIZZLE_2 --swizzle=2) set(TEST_SWIZZLE_5 --swizzle=5) set(TEST_SWIZZLE_5_UNEVEN --swizzle=5 --m=4096 --n=16384) -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") cutlass_example_add_executable( 70_blackwell_fp16_gemm 70_blackwell_fp16_gemm.cu diff --git a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu index f911262f..6e25c971 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu +++ b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu @@ -449,9 +449,9 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } - - if (!(props.major == 10 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; return 0; } diff --git a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt index a326f461..87bcd514 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt +++ b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # Both filenames are shorter to avoid MAX_PATH issues on Windows. -if (CUTLASS_NVCC_ARCHS MATCHES 100a) +if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") cutlass_example_add_executable( 71_blackwell_gemm_with_collective_builder 71_blackwell_gemm_with_collective_builder.cu diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu index 403472ad..390012f2 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu @@ -116,7 +116,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O // Kernel Perf config using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size -using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster +using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -511,10 +511,10 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 10 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; return 0; - } + } // // Parse options diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index 6c28c552..e3ad25fe 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -566,8 +566,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 10 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; return 0; } diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu index aff311c1..e157c6ca 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu @@ -117,7 +117,7 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O // Kernel Perf config using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size -using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster +using ClusterShape = Shape<_2,_4,_1>; // Shape of the threadblocks in a cluster using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -512,8 +512,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 10 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; return 0; } diff --git a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt index eaeb6600..45ccdcca 100644 --- a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt +++ b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt @@ -28,7 +28,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if (CUTLASS_NVCC_ARCHS MATCHES 100a) +if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") cutlass_example_add_executable( 72a_blackwell_nvfp4_bf16_gemm 72a_blackwell_nvfp4_bf16_gemm.cu diff --git a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt index a4a18324..4ac31f62 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt +++ b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt @@ -28,7 +28,7 @@ -if (CUTLASS_NVCC_ARCHS MATCHES 100a) +if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") cutlass_example_add_executable( 73_blackwell_gemm_preferred_cluster blackwell_gemm_preferred_cluster.cu diff --git a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu index 67b82a6e..df805051 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu +++ b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu @@ -513,7 +513,7 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); if (props.major != 10 || props.minor != 0) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100)." << std::endl; + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; return 0; } diff --git a/examples/74_blackwell_gemm_streamk/CMakeLists.txt b/examples/74_blackwell_gemm_streamk/CMakeLists.txt index 5a378241..4f808e85 100644 --- a/examples/74_blackwell_gemm_streamk/CMakeLists.txt +++ b/examples/74_blackwell_gemm_streamk/CMakeLists.txt @@ -29,9 +29,9 @@ -if (CUTLASS_NVCC_ARCHS MATCHES 100a) -cutlass_example_add_executable( - 74_blackwell_gemm_streamk - blackwell_gemm_streamk.cu +if(CUTLASS_NVCC_ARCHS STREQUAL "100a" OR CUTLASS_NVCC_ARCHS STREQUAL "100f" OR CUTLASS_NVCC_ARCHS STREQUAL "101a" OR CUTLASS_NVCC_ARCHS STREQUAL "101f" OR CUTLASS_NVCC_ARCHS STREQUAL "103a" OR CUTLASS_NVCC_ARCHS STREQUAL "103f") + cutlass_example_add_executable( + 74_blackwell_gemm_streamk + blackwell_gemm_streamk.cu ) endif() diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu index add938a9..31e5c2e0 100644 --- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -556,10 +556,19 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; - return 0; + if (__CUDACC_VER_MAJOR__ < 13) { + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; + return 0; + } } + else { + if ((props.major != 10 || props.major != 11) && props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl; + return 0; + } + } + // // Parse options // diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu index 097a8693..84c42b91 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu @@ -762,9 +762,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (!(props.major == 10 && props.minor == 0)) { - std::cerr - << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n"; + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; return 0; } diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index f363dfa0..a18828e2 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -138,8 +138,7 @@ using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor // Core kernel configurations using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature -using EpilogueOperatorClass = cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag -using MainloopOperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size // Runtime Cluster Shape @@ -159,7 +158,7 @@ struct MMA2SMConfig { }; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, EpilogueOperatorClass, + ArchTag, OperatorClass, typename MMA1SMConfig::MmaTileShape, ClusterShape, Shape<_128,_64>, ElementAccumulator, ElementAccumulator, @@ -169,7 +168,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui // , FusionOperation // Enable for SF Output >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, MainloopOperatorClass, + ArchTag, OperatorClass, ElementA, LayoutA *, AlignmentA, ElementB, LayoutB *, AlignmentB, ElementAccumulator, @@ -187,7 +186,7 @@ using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; using Gemm = Gemm1SM; using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, EpilogueOperatorClass, + ArchTag, OperatorClass, typename MMA2SMConfig::MmaTileShape, ClusterShape, Shape<_128,_64>, ElementAccumulator, ElementAccumulator, @@ -197,13 +196,13 @@ using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::Collective // , FusionOperation // Enable for SF Output >::CollectiveOp; using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, MainloopOperatorClass, + ArchTag, OperatorClass, ElementA, LayoutA *, AlignmentA, ElementB, LayoutB *, AlignmentB, ElementAccumulator, typename MMA2SMConfig::MmaTileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + static_cast(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>, typename MMA2SMConfig::KernelSchedule >::CollectiveOp; using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal< @@ -233,7 +232,7 @@ using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; std::vector stride_A_host; std::vector stride_B_host; std::vector layout_SFA_host; -std::vector layout_SFB_host; +std::vector layout_SFB_host; std::vector stride_C_host; std::vector stride_D_host; @@ -897,9 +896,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (!(props.major == 10 && props.minor == 0)) { - std::cerr - << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 100a).\n"; + if (props.major != 10 || (props.minor != 0 && props.minor != 1 && props.minor != 3)) { + std::cerr << "This example requires a GPU with compute capability 100a|f, 101a|f, or 103a|f)." << std::endl; return 0; } diff --git a/examples/75_blackwell_grouped_gemm/CMakeLists.txt b/examples/75_blackwell_grouped_gemm/CMakeLists.txt index 0ce48662..304a49f8 100644 --- a/examples/75_blackwell_grouped_gemm/CMakeLists.txt +++ b/examples/75_blackwell_grouped_gemm/CMakeLists.txt @@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes -if (CUTLASS_NVCC_ARCHS MATCHES 100a) +if(CUTLASS_NVCC_ARCHS STREQUAL "100a") cutlass_example_add_executable( 75_blackwell_grouped_gemm 75_blackwell_grouped_gemm.cu diff --git a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu index 91511235..f548e890 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu @@ -504,10 +504,19 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; - return 0; + if (__CUDACC_VER_MAJOR__ < 13) { + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; + return 0; + } } + else { + if ((props.major != 10 || props.major != 11) && props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl; + return 0; + } + } + // // Parse options // diff --git a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu index 15ec02aa..49da2af6 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_fprop.cu @@ -504,10 +504,19 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; - return 0; + if (__CUDACC_VER_MAJOR__ < 13) { + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; + return 0; + } } + else { + if ((props.major != 10 || props.major != 11) && props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl; + return 0; + } + } + // // Parse options // diff --git a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu index e47dbece..a491bed8 100644 --- a/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu +++ b/examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu @@ -500,10 +500,19 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; - return 0; - } + if (__CUDACC_VER_MAJOR__ < 13) { + if (props.major != 10 && (props.minor != 0 || props.minor != 1)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 101)." << std::endl; + return 0; + } + } + else { + if ((props.major != 10 || props.major != 11) && props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 100 or 110)." << std::endl; + return 0; + } + } + // // Parse options // diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index edaf76f9..c848fcfa 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -163,7 +163,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC - TEST_MLA_SEP_REDUCTION + TEST_MLA_SEP_REDUCTION TEST_MLA_FUSE_REDUCTION ) target_include_directories(77_blackwell_mla_2sm_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) @@ -175,8 +175,8 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla.cu TEST_COMMAND_OPTIONS TEST_MLA_BASIC - TEST_MLA_SEP_REDUCTION - TEST_MLA_FUSE_REDUCTION + TEST_MLA_SEP_REDUCTION + TEST_MLA_FUSE_REDUCTION ) target_include_directories(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_mla_2sm_cpasync_${PREC} PRIVATE ${PREC_MACRO} CPASYNC) diff --git a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp index 2326c641..5c5de849 100644 --- a/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -100,7 +100,7 @@ public: cutlass::fmha::kernel::FmhaKernelBwdConvert >; - using OperationMha= cutlass::fmha::device::FMHA< + using OperationNormal= cutlass::fmha::device::FMHA< cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< ProblemShape, Element, ElementAccumulator, TileShape, Mask > @@ -112,7 +112,7 @@ public: > >; - using Operation = std::conditional_t; + using Operation = std::conditional_t; using Kernel = typename Operation::Kernel; diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp index 92c7d371..5fd8a53c 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -365,7 +365,7 @@ struct Sm100FmhaGenKernelWarpspecialized { pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; } pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; - pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp); typename CollectiveMainloop::PipelineE pipeline_corr_epi( shared_storage.pipelines.corr_epi, pipeline_corr_epi_params, diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu index f50e85b4..cd4231c0 100644 --- a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu +++ b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu @@ -117,7 +117,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, - cutlass::epilogue::NoSmemWarpSpecialized2Sm + cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm >::CollectiveOp; // Build the mainloop diff --git a/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu index 7e186c72..f2f585b8 100644 --- a/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79a_blackwell_geforce_nvfp4_bf16_gemm.cu @@ -88,7 +88,7 @@ using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,7 +189,7 @@ cutlass::HostTensor block_C; cutlass::HostTensor block_D; // Reference Output Tensor cutlass::HostTensor block_reference_D; -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) template auto make_iterator(T* ptr) { @@ -283,7 +283,7 @@ struct Result }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -489,19 +489,28 @@ int run(Options &options) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example - // and must have compute capability at least 100. + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { - std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; @@ -509,8 +518,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 12 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl; return 0; } @@ -530,9 +539,9 @@ int main(int argc, char const **args) { // // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) run(options); -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) return 0; } diff --git a/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu index a0aa01c6..d929823b 100644 --- a/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79b_blackwell_geforce_nvfp4_nvfp4_gemm.cu @@ -86,7 +86,7 @@ using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -217,7 +217,7 @@ cutlass::HostTensor block_refer // Matrix-wide normalization constant cutlass::HostTensor block_Normconst; -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) template auto make_iterator(T* ptr) { @@ -311,7 +311,7 @@ struct Result }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -536,19 +536,28 @@ int run(Options &options) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example - // and must have compute capability at least 100. + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { - std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; @@ -556,8 +565,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 12 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl; return 0; } @@ -577,9 +586,9 @@ int main(int argc, char const **args) { // // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) run(options); -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) return 0; } diff --git a/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu index 655719d9..f50f14d7 100644 --- a/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79c_blackwell_geforce_mixed_mxfp8_mxfp6_bf16_gemm.cu @@ -88,7 +88,7 @@ using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -189,7 +189,7 @@ cutlass::HostTensor block_C; cutlass::HostTensor block_D; // Reference Output Tensor cutlass::HostTensor block_reference_D; -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) template auto make_iterator(T* ptr) { @@ -283,7 +283,7 @@ struct Result }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -489,19 +489,28 @@ int run(Options &options) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example - // and must have compute capability at least 100. + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { - std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; @@ -509,8 +518,8 @@ int main(int argc, char const **args) { CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 12 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl; return 0; } @@ -530,9 +539,9 @@ int main(int argc, char const **args) { // // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) run(options); -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) return 0; } diff --git a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu index 3342eebb..d3ebecd1 100644 --- a/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu +++ b/examples/79_blackwell_geforce_gemm/79d_blackwell_geforce_nvfp4_grouped_gemm.cu @@ -97,7 +97,7 @@ using namespace cute; using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -263,7 +263,7 @@ cutlass::DeviceAllocation block_beta; // NormConst is a single device-side constant value, its not per-batch or per-group cutlass::DeviceAllocation norm_constant_device; -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) template auto make_iterator(T* ptr) { @@ -466,7 +466,7 @@ struct Result bool passed = false; }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation @@ -861,30 +861,39 @@ int run(Options &options, bool host_problem_shapes_available = true) return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || ((__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8) ) ) { - std::cerr << "This example requires CUDA 12.8 or newer.\n"; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support.\n"; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); cudaError_t error = cudaGetDeviceProperties(&props, 0); - if (!(props.major == 12 && props.minor == 0)) { + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { std::cerr - << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120a).\n"; + << "This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 120 or 121).\n"; return 0; } @@ -901,7 +910,7 @@ int main(int argc, char const **args) { return 0; } -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) allocate(options); initialize(options); diff --git a/examples/79_blackwell_geforce_gemm/CMakeLists.txt b/examples/79_blackwell_geforce_gemm/CMakeLists.txt index b689c85e..a48db16d 100644 --- a/examples/79_blackwell_geforce_gemm/CMakeLists.txt +++ b/examples/79_blackwell_geforce_gemm/CMakeLists.txt @@ -46,7 +46,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes -if (CUTLASS_NVCC_ARCHS MATCHES 120a) +if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a") cutlass_example_add_executable( 79a_blackwell_geforce_nvfp4_bf16_gemm 79a_blackwell_geforce_nvfp4_bf16_gemm.cu diff --git a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu index e7253c24..ee679f70 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu +++ b/examples/80_blackwell_geforce_sparse_gemm/80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu @@ -78,7 +78,7 @@ #include "helper.h" using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -248,7 +248,7 @@ struct Result avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) {} }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -507,25 +507,34 @@ int run(Options &options) } return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example - // and must have compute capability at least 120. + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { - std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 12 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl; return 0; } // @@ -540,9 +549,9 @@ int main(int argc, char const **args) { // // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) run(options); -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu index b5ba430e..c19a0948 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu +++ b/examples/80_blackwell_geforce_sparse_gemm/80b_blackwell_geforce_nvfp4_nvfp4_sparse_gemm.cu @@ -78,7 +78,7 @@ #include "helper.h" using namespace cute; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM kernel configurations ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -183,7 +183,7 @@ cutlass::HostTensor bloc cutlass::HostTensor block_reference_D; cutlass::HostTensor block_reference_SFD; cutlass::HostTensor block_Normconst; -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) template auto make_iterator(T* ptr) { return cute::recast_ptr(ptr); @@ -259,7 +259,7 @@ struct Result avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) {} }; -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) ///////////////////////////////////////////////////////////////////////////////////////////////// /// GEMM setup and evaluation ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -531,25 +531,34 @@ int run(Options &options) } return 0; } -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { - // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit to run this example - // and must have compute capability at least 120. + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { - std::cerr << "This example requires CUDA 12.8 or newer." << std::endl; + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; // Returning zero so this test passes on older Toolkits. Its actions are no-op. return 0; } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif cudaDeviceProp props; int current_device_id; CUDA_CHECK(cudaGetDevice(¤t_device_id)); CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); - if (!(props.major == 12 && props.minor == 0)) { - std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120)." << std::endl; + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 120 or 121)." << std::endl; return 0; } // @@ -564,9 +573,9 @@ int main(int argc, char const **args) { // // Evaluate CUTLASS kernels // -#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) run(options); -#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt index 6a94fb0d..2eb110e6 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt +++ b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if (CUTLASS_NVCC_ARCHS MATCHES 120a) +if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a") cutlass_example_add_executable( 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu diff --git a/examples/81_blackwell_gemm_blockwise/README.md b/examples/81_blackwell_gemm_blockwise/README.md new file mode 100644 index 00000000..9fe03bab --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/README.md @@ -0,0 +1,104 @@ +# Blockwise and Groupwise GEMM and Grouped GEMM on Blackwell + +Blockwise and Groupwise GEMM and Grouped GEMM implement software scaling by the accumulator type. +The examples in this directory aim to demonstrate how we can instantiate this kernel and run it. +The profiler enables instantiating and profiling different kernel configurations for Blockwise and Groupwise GEMM +to determine the best performing kernel for your workload. + +## Introduction +Blockwise and Groupwise GEMM operations enable fine-grained numerical precision control by applying scale factors at configurable granularities. This is particularly useful for quantized neural networks where different regions of tensors may have different scaling requirements. + +For a GEMM $D = \alpha A B + \beta C$, we introduce two scale factor tensors, SFA +and SFB. This leads to a GEMM $D = \alpha \text{SFA} * A \text{ SFB} * B + \beta C$. + +## Scale Factor Tensors +- *SFA*: Broadcast the same scale within a block defined by _scale granularity m_ and _scale granularity k_ when scaling A. + - Scale granularity m and scale granularity k are also referred to as _scale vector m_ and _k_ respectively. +- *SFB*: Broadcast the same scale within a block defined by _scale granularity n_ and _scale granularity k_ when scaling B. + - Scale granularity n and scale granularity k are also referred to as _scale vector n_ and _k_ respectively. + +These can be represented in CuTe as: +- *SFA Layout*: $((\text{scale granularity M}, M / \text{scale granularity M}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$ +- *SFB Layout*: $((\text{scale granularity N}, M / \text{scale granularity M}), (\text{scale granularity K}, K / \text{scale granularity K})) : ((0, int), (0, int))$ + +The 0 element stride ensures the same group of coordinates to map to the same element in the scale factors. + +## Configuration + +For convenience the Blockwise and Groupwise implementation provide +`cutlass::detail::Sm100BlockwiseScaleConfig` +to deduce layouts and manage compact tensors. + +`cutlass::detail::Sm100BlockwiseScaleConfig` by default makes +every tensor major the M/N mode, but can be configured. For example: +`cutlass::detail::Sm100BlockwiseScaleConfig` +denotes SFA will be major in the K dimension but SFB will be major in the N dimension. + +## Integration with Other Frameworks + +If translating from frameworks like Torch where SFA has shape +(M / ScaleGranularityM, K / ScaleGranularityK) and SFB has a shape (K / ScaleGranularityK, N / ScaleGranularityN), +ensure to transpose SFB and B to fit into the canonical CuTe layout form. This ensures K is always the second mode. +Use strides can be used to determine if each tensor is MN or K major to correctly form the layouts either directly +or with the convenience wrappers. + + +## Kernel Selection and Profiling + +To determine the most performance Blockwise/Groupwise GEMM or Grouped GEMM kernel for your use case, you can utilize the +[CUTLASS profiler](../../media/docs/cpp/profiler.md). + +All Blockwise/Groupwise GEMMs and Group GEMMs with `f32` scaling of `e4m3` or runtime `f8` types can be selected by +selecting a subset of kernels when configuring with CMake by passing: +`-DCUTLASS_LIBRARY_KERNELS="cutlass3x*f32xe4m3_*f32xe4m3*,cutlass3x*f32xf8_*f32xf8*"`. + +The simplest way to use the profiler is to pass `m`, `n`, and `k` as well as your `scale_vec_size_m`, +`scale_vec_size_n`, and `scale_vec_size_k`. Passing `enable-best-kernel-for-fixed-shape` will do some autotuning +per kernel to determine best rasterization orders, swizzles, and cluster sizes. Passing `blockwiseGemm` +or `GroupedGemm` through the operation flag will determine which set of operations will be profiled. + +For examle, this command using the cutlass profiler will dump the performance of all compiled kernels which support scale +granularity m = 1, scale granularity n = 128, and scale granularity k = 128 for the problem size 8192x8192x8192: +``` +cutlass_profiler --operation=blockwiseGemm \ + --enable-best-kernel-for-fixed-shape \ + --m=8192 --n=8192 --k=8192 \ + --scale_vec_size_m=1 --scale_vec_size_n=128 --scale_vec_size_k=128 \ + --verification-enabled=false +``` + +### Kernel Naming Convention + +The naming of the blockwise and groupwise kernels includes the following new pattern: for each tensor scalar pair we have +`xx`. For example +`cutlass3x_sm100_tensorop_gemm_64x128f32xe4m3_1x128f32xe4m3_f32_f16_f16_64x128x128_1x1x1_0_nnn_align16_1sm` would denote: +- A CUTLASS 3 GEMM for SM100 that uses tensor cores. +- SFA is f32 with a 64 element scale granularity m and a 128 element scale granularity k. +- The A matrix is e4m3. +- SFB is f32 with a 1 element scale granularity n and a 128 element scale granularity k. +- The B matrix is e4m3. +- The epilogue is done in f32. +- The C matrix is f16. +- The D matrix is f16. +- The MMA tile shape is 64x128x128. +- The cluster shape is 1x1x1. +- A, B, C, and D are all column major. +- The alignment of the major modes are 16 elements for A, B, C, and D. +- The MMA variant is a 1SM instruction. + +It is also worthwhile to note that C can be void if scaling by beta is not needed. + +## Performance Tips and Tricks + +- *MMA Dimensions*: in both Blackwell and Hopper tensor cores it is worthwhile to note that the smallest `MMA_M` dimension is 64, but `MMA_N` +dimension can be as small as 8 for some instructions. For problem sizes where M is small consider computing $D^T = \alpha B^T A^T + \beta C^T$ instead. + - When computing after swapping A and B and transposing the N dimension is now our small dimension. With a small `MMA_N` we can more effectively tile without performing unecessary computation. +- *Layout Swapping*: When optimizing with the profiler swap `m` and `n` inputs and adjust layouts to reflect this swapping and transposing. + - For example if we have a row-major A, column-major B, and row-major D, we can swap tensors and run a kernel with: + - The left hand matrix as row-major (since B transposed is row-major) + - A right hand matrix as column-major (since A transposed is column-major) + - A column-major output (since D transposed is column-major). + +When using blockwise and groupwise GEMM we must swap the scale vector sizes when doing this optimization. If we have a 1 element scale granularity M +and a 128 element scale granularity N, we must run a kernel with a 128 element scale granularity M and a 1 element scale granularity +N. diff --git a/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu new file mode 100644 index 00000000..dd748452 --- /dev/null +++ b/examples/86_blackwell_mixed_dtype_gemm/86_blackwell_mixed_dtype.cu @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" + +#include "helper.h" + +#include "mixed_dtype_helper.cuh" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::bfloat16_t; +using QuantType = cutlass::int4b_t; +using AccumulatorType = float; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementZero = MmaType; +using ElementScale = MmaType; + +// C/D matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = AccumulatorType; // Element type for internal accumulation +using ElementCompute = AccumulatorType; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using MmaTileShape = Shape<_256,_128,_128>; // (MmaTileShape_N, MmaTileShape_M, MmaTileShape_K) as A and B will be swapped +using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmMixedInputSm100; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +constexpr int ScaleGranularityN = 1; //Should be less than or equal to GEMM_N +constexpr int ScaleGranularityK = 128; //Should be less than or equal to GEMM_K +using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig; +using LayoutScale = decltype(ScaleConfig::deduce_layout_scale()); // Layout type for SFA matrix operand +LayoutScale layout_S; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// ============================================================ MIXED INPUT NO SCALES ============================================================================ + //The collective will infer that the narrow type should be upcasted to the wide type. + //We swap A and B operands to the builder here +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + MainloopSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopConvertOnly, + CollectiveEpilogue +>; + +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, cute::tuple, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + MainloopSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ================================================================== +// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format. +using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, cute::tuple, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + MainloopSchedule + >::CollectiveOp; + +using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleWithZeroPoint, + CollectiveEpilogue +>; + +using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter; +// ================================================================================================================================================================= + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; + +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +// Scale and Zero share a stride since the layout and shapes must be the same. +using StrideS = typename cute::Stride, int64_t, int64_t>; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(MixedDtypeOptions const& options) { + + auto shape_b = cute::make_shape(options.n, options.k, options.l); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + layout_S = ScaleConfig::tile_atom_to_shape_scale(make_shape(options.n, options.k, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_S))); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(blockscale_b_coord.product()); + block_zero.reset(blockscale_b_coord.product()); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + + if(options.verify){ + auto layout_B = make_layout(shape_b, stride_B); + auto scale_stride = layout_S.stride(); + auto layout_scale_zero = make_layout( + make_shape(size<0>(layout_S), size<1,1>(layout_S), size<2>(layout_S)), + make_stride(size<0,1>(scale_stride), size<1,1>(scale_stride), size<2>(scale_stride)) + ); //layout = (options.n, scale_k, options.l) : (_1, options.n, _0) + cudaStream_t stream = cudaStreamDefault; + cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, ScaleGranularityK, stream); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +Args args_from_options(MixedDtypeOptions const& options) +{ +// Swap the A and B tensors, as well as problem shapes here. + if constexpr (KernelConversionMode == cutlass::detail::ConversionMode::DirectConvert) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScale) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if constexpr(KernelConversionMode == cutlass::detail::ConversionMode::ConvertAndScaleWithZero) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), layout_S, block_zero.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } else { + exit(-1); + } +} + +bool verify(MixedDtypeOptions const& options) { + // + // Compute reference output + // + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-2f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + return passed; +} + +/// Execute a given example GEMM computation +template +int run(MixedDtypeOptions &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + if(options.verify){ + result.passed = verify(options); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + } + else{ + result.passed = true; + std::cout << " Verification: Off " << std::endl; + } + if (!result.passed) { + exit(-1); + } + mixed_dtype_profiling(gemm, options, result); + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 Toolkit to run this example + // and must have compute capability at least 100a. + bool is_correct_cuda_version = (__CUDACC_VER_MAJOR__ >= 12) && (__CUDACC_VER_MINOR__ >= 8); + if (!is_correct_cuda_version) { + std::cerr << "Version is " << __CUDACC_VER_MINOR__ << "\n"; + std::cerr << "This example requires CUDA 12.8 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr + << "This example requires a GPU of NVIDIA's Blackwell Architecture or " + << "later (compute capability 100a or greater).\n"; + return 0; + } + + // + // Parse options + // + + MixedDtypeOptions options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + std::cout << "Running in conversion only mode." << std::endl; + run(options); + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + std::cout << "Running in scale mode." << std::endl; + run(options); + } + else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + std::cout << "Running in scale and zero mode." << std::endl; + run(options); + } + else{ + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/86_blackwell_mixed_dtype_gemm/CMakeLists.txt b/examples/86_blackwell_mixed_dtype_gemm/CMakeLists.txt new file mode 100644 index 00000000..16e1a0cf --- /dev/null +++ b/examples/86_blackwell_mixed_dtype_gemm/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_S_TILE_SHAPE --m=256 --n=128 --k=32 --verify --iterations=0) +set(TEST_S_TILE_SHAPE_MULTIPLE_KITER --m=256 --n=128 --k=128 --verify --iterations=0) +set(TEST_S_DIFFERENT_MN --m=16384 --n=4608 --k=4608 --verify --iterations=0) +set(TEST_S_ONE_WAVE --m=1536 --n=1536 --k=32 --verify --iterations=0) # Assuming 144 SMs +set(TEST_S_2048 --m=2048 --n=2048 --k=2048 --verify --iterations=0) # Multi-wave + +if(NOT WIN32) + cutlass_example_add_executable( + 86_blackwell_mixed_dtype_gemm + 86_blackwell_mixed_dtype.cu + TEST_COMMAND_OPTIONS + TEST_S_TILE_SHAPE + TEST_S_TILE_SHAPE_MULTIPLE_KITER + TEST_S_ONE_WAVE + TEST_S_2048 + ) +endif() diff --git a/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh b/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh new file mode 100644 index 00000000..f26e6be8 --- /dev/null +++ b/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + bool verify = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 1000; + int warmup = 1000; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("verify")) { + verify = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "86_blackwell_mixed_dtype_gemm\n\n" + << " Blackwell Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n" + << " --verify= Run verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "86_blackwell_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); + } else { + // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + } + return true; +} + diff --git a/examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu b/examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu new file mode 100644 index 00000000..8a4360a9 --- /dev/null +++ b/examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu @@ -0,0 +1,518 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS. + + This example demonstrates a simple way to instantiate and run a blockwise scaling FP8 GEMM on the NVIDIA Blackwell SM120 architecture. + This kernel is optimized for the GeForce RTX 50 series GPUs. + + This kernel accepts Inputs A and B with TileMxTileK and TileNxTileK FP32 block scaling, performing scaling and accumulation every TileK elements. + Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: + + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Epilogue Optimization + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + 3. Runtime scaling block size. + + Usage: + + $ ./examples/87_blackwell_geforce_gemm_blockwise/87a_blackwell_geforce_fp8_bf16_gemm_blockwise --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +#include "./utils.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; // Element Accumulator will also be our scale factor type +using ElementCompute = float; + + +// MMA and Cluster Tile Shapes +// Shape of the tile +using MmaTileShape_MNK = Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; + +using ScaleConfig = decltype(cutlass::detail::sm120_trivial_blockwise_scale_config(MmaTileShape_MNK{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +// Strides just iterate over scalars and have no zeros +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; +// Layouts are tiled to the problem size and the strides have zeros +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "87a_blackwell_geforce_gemm_blockwise\n\n" + << " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "87a_blackwell_geforce_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + tensor_SFA.resize(blockscale_a_coord); + tensor_SFB.resize(blockscale_b_coord); + + initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + + initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025); + initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, + tensor_B.device_data(), stride_B, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA); + auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu b/examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu new file mode 100644 index 00000000..a90d24cc --- /dev/null +++ b/examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu @@ -0,0 +1,539 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS. + + This example demonstrates a simple way to instantiate and run cooperative and ping-pong groupwise scaling FP8 GEMMs on the NVIDIA Blackwell SM120 architecture. + These kernels are optimized for GeForce RTX 50 series GPUs. + + The blockscaling kernels accept Inputs A and B with 1xTileK and TileNxTileK FP32 block scaling, performing scaling and accumulation every TileK elements. + The ping-pong kernel leverages a smaller tile shape to avoid register spilling for better performance. + Similar to 79a_blackwell_geforce_nvfp4_bf16_gemm, this kernel leverages: + + 1. Warp-Specialized persistent kernel design that supports both cooperative and ping-pong kernel schedule introduced in Hopper. + 2. The new SW controlled dynamic scheduler based on cluster launch control (See https://docs.nvidia.com/cuda/parallel-thread-execution). + 3. Epilogue Optimization + + Note that GeForce RTX 50 series GPUs do not support: + 1. Multicast feature of TMA load. Cluster shape has to be 1x1x1. + 2. Dynamic datatypes. + 3. Runtime scaling block size. + + Usage: + + $ ./examples/87_blackwell_geforce_gemm_blockwise/87b_blackwell_geforce_fp8_bf16_gemm_groupwise --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +#include "./utils.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; // Element Accumulator will also be our scale factor type +using ElementCompute = float; + + +// MMA and Cluster Tile Shapes +// Shape of the tile +using CooperativeMmaTileShape_MNK = Shape<_128,_128,_128>; +// Smaller tile size for pingpong schedule to avoid register spilling +using PingpongMmaTileShape_MNK = Shape<_64, _128, _128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; +using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig; + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +template +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +template +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + Schedule // cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120 + >::CollectiveOp; + +template +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +// We are using cooperative kernel schedule by default +using CooperativeGemm = cutlass::gemm::device::GemmUniversalAdapter< + GemmKernel>; + +// Pingpong kernel +using PingpongGemm = cutlass::gemm::device::GemmUniversalAdapter< + GemmKernel>; + +using StrideA = typename CooperativeGemm::GemmKernel::StrideA; +using StrideB = typename CooperativeGemm::GemmKernel::StrideB; +using StrideC = typename CooperativeGemm::GemmKernel::StrideC; +using StrideD = typename CooperativeGemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +// Strides just iterate over scalars and have no zeros +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; +// Layouts are tiled to the problem size and the strides have zeros +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "87b_blackwell_geforce_gemm_groupwise\n\n" + << " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "87b_blackwell_geforce_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + tensor_SFA.resize(blockscale_a_coord); + tensor_SFB.resize(blockscale_b_coord); + + initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + + initialize_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025); + initialize_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, + tensor_B.device_data(), stride_B, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA); + auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + printf("Running kernel with Cooperative kernel schedule:\n"); + run(options); + + printf("Running kernel with Pingpong kernel schedule:\n"); + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu b/examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu new file mode 100644 index 00000000..467f8145 --- /dev/null +++ b/examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu @@ -0,0 +1,678 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief An FP8 groupwise scaled grouped GEMM example for the NVIDIA Blackwell SM120 architecture using CUTLASS. + + This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM120 TensorOp-based warp-specialized kernel + for FP8 with per-group:1x128x128 FP32 scaling factors. + In this example, M, N, and K are fixed across groups. + As RTX 50 series GPUs do not support runtime scaling block sizes, all groups share the same block scaling size. + For this example all scheduling work is performed on the device, utilizing the device-side modification of TMA descriptors + to move between groups/problem_count (represented by groups). + https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device + + To run this example: + + $ ./examples/87_blackwell_geforce_gemm_blockwise/87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Same applies for alpha and beta values that are randomized across the different groups. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +#include "./utils.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +using ProblemShape = cutlass::gemm::GroupProblemShape>; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; // Element Accumulator will also be our scale factor type +using ElementCompute = float; + + +// MMA and Cluster Tile Shapes +// Shape of the tile +using MmaTileShape_MNK = Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; + +// Scaling Factors +using ElementSF = ElementAccumulator; + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; +using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig; + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelScheduleSm120Blockwise + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; +static_assert(cute::is_same_v); +static_assert(cute::is_same_v); + + +/// Initialization +uint64_t seed; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; + +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; +using HostTensorSFA = cutlass::HostTensor; +using HostTensorSFB = cutlass::HostTensor; + +std::vector block_A; +std::vector block_B; +std::vector block_C; +std::vector block_D; +std::vector block_ref_D; +std::vector block_SFA; +std::vector block_SFB; + +cutlass::DeviceAllocation problem_sizes; +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; + +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1, groups = 10; + std::vector problem_sizes_host; + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char, 'N'); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + + for (int i = 0; i < groups; ++i) { + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "87c_blackwell_geforce_grouped_gemm_groupwise\n\n" + << " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "87c_blackwell_geforce_grouped_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --groups=8 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * groups; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_SFA_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto [M, N, K] = problem; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + + auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + stride_A_host.push_back(stride_A); + stride_B_host.push_back(stride_B); + layout_SFA_host.push_back(layout_SFA); + layout_SFB_host.push_back(layout_SFB); + stride_C_host.push_back(stride_C); + stride_D_host.push_back(stride_D); + + block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A)))); + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFA.push_back(HostTensorSFA(cutlass::make_Coord(size(filter_zeros(layout_SFA))))); + block_SFB.push_back(HostTensorSFB(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + } + + for (int i = 0; i < options.groups; ++i) { + initialize_tensor(block_A.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(block_B.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(block_C.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2024); + initialize_tensor(block_SFA.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2025); + initialize_tensor(block_SFB.at(i).host_view(), cutlass::Distribution::Uniform, seed + 2026); + + block_A.at(i).sync_device(); + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFA.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_A_host.at(i) = block_A.at(i).device_data(); + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFA_host.at(i) = block_SFA.at(i).device_data(); + ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + + alpha_host.push_back((options.alpha == std::numeric_limits::max()) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == std::numeric_limits::max()) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + if (options.alpha != std::numeric_limits::max()) { + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.dAlpha = {_0{}, _0{}, 0}; + } else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + } + + if (options.beta != std::numeric_limits::max()) { + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + fusion_args.dBeta = {_0{}, _0{}, 0}; + } else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + fusion_args.dBeta = {_0{}, _0{}, 1}; + } + + arguments = { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), + ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), + ptr_SFB.get(), layout_SFB.get()}, + { + fusion_args, + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + hw_info, scheduler + }; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + bool passed = true; + + for (int i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto [M, N, K] = problem; + + auto A = cute::make_tensor(block_A.at(i).host_data(), + cute::make_layout(cute::make_shape(M, K, 1), stride_A_host.at(i))); + auto B = cute::make_tensor(block_B.at(i).host_data(), + cute::make_layout(cute::make_shape(N, K, 1), stride_B_host.at(i))); + auto C = cute::make_tensor(block_C.at(i).host_data(), + cute::make_layout(cute::make_shape(M, N, 1), stride_C_host.at(i))); + auto D = cute::make_tensor(block_ref_D.at(i).host_data(), + cute::make_layout(cute::make_shape(M, N, 1), stride_D_host.at(i))); + auto SFA = cute::make_tensor(block_SFA.at(i).host_data(), layout_SFA_host.at(i)); + auto SFB = cute::make_tensor(block_SFB.at(i).host_data(), layout_SFB_host.at(i)); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha_host.at(i); + epilogue_params.beta = beta_host.at(i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + + block_D.at(i).sync_host(); + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << " " << options.groups << " Groups" << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.8 or higher Toolkit for SM120 support, + // or CUDA 12.9 or higher for SM121 support. + // Must have compute capability at least 120. +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) { + std::cerr << "This example requires CUDA 12.8 or newer for SM120 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#elif defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer for SM121 support." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } +#endif + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (!(props.major == 12 && (props.minor == 0 || props.minor == 1))) { + std::cerr << "This example requires a GPU with compute capability 120a or 121a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + initialize(options); + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/87_blackwell_geforce_gemm_blockwise/CMakeLists.txt b/examples/87_blackwell_geforce_gemm_blockwise/CMakeLists.txt new file mode 100644 index 00000000..dec9e3dc --- /dev/null +++ b/examples/87_blackwell_geforce_gemm_blockwise/CMakeLists.txt @@ -0,0 +1,47 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a") +cutlass_example_add_executable( + 87a_blackwell_geforce_fp8_bf16_gemm_blockwise + 87a_blackwell_geforce_fp8_bf16_gemm_blockwise.cu +) + +cutlass_example_add_executable( + 87b_blackwell_geforce_fp8_bf16_gemm_groupwise + 87b_blackwell_geforce_fp8_bf16_gemm_groupwise.cu +) + +cutlass_example_add_executable( + 87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise + 87c_blackwell_geforce_fp8_bf16_grouped_gemm_groupwise.cu +) + +endif() diff --git a/examples/87_blackwell_geforce_gemm_blockwise/utils.h b/examples/87_blackwell_geforce_gemm_blockwise/utils.h new file mode 100644 index 00000000..72735303 --- /dev/null +++ b/examples/87_blackwell_geforce_gemm_blockwise/utils.h @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} diff --git a/examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm.cu b/examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm.cu new file mode 100644 index 00000000..d6d7879c --- /dev/null +++ b/examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm.cu @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A GEMM example using CUTLASS for the NVIDIA Blackwell SM103 architecture. + + This example demonstrates a simple way to instantiate and run a blockscaled 3xFP4 GEMM on the NVIDIA Blackwell SM103 architecture. + + Usage: + + $ ./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + + +#include + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e2m1_t; // Element type for A matrix operand +using ElementSFA = cutlass::float_ue4m3_t; +using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e2m1_t; // Element type for A matrix operand +using ElementSFB = cutlass::float_ue4m3_t; +using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand +using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand +using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand +using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) +// Kernel functional config +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm103; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + +// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands + +// Kernel Perf config +using MmaTileShape = cute::Shape>; // MMA's tile size +using ClusterShape = cute::Shape; // Shape of the threadblocks in a cluster + +// Epilogue fusion operator +using EpilogueFusionOp = cutlass::epilogue::fusion::LinearCombination; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized1Sm, + EpilogueFusionOp + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutATag, AlignmentA, + cute::tuple, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using StrideA = typename Gemm::GemmKernel::StrideA; +using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{})); +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideB = typename Gemm::GemmKernel::StrideB; +using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{})); +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride. +using StrideC = typename Gemm::GemmKernel::StrideC; +using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{})); +using StrideD = typename Gemm::GemmKernel::StrideD; +using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{})); + +// +// Data members +// + +/// Initialization +StrideA stride_A; +LayoutA layout_A; +LayoutSFA layout_SFA; +StrideB stride_B; +LayoutB layout_B; +LayoutSFB layout_SFB; +StrideC stride_C; +LayoutC layout_C; +StrideD stride_D; +LayoutD layout_D; +uint64_t seed; + +// The HostTensors are only used for allocating memory on host and device, and transferring data between host and device +// Use cute::Tensor and cute::Layout for iterating thru the matrix elements +cutlass::HostTensor block_A; +cutlass::HostTensor block_SFA; +cutlass::HostTensor block_B; +cutlass::HostTensor block_SFB; +cutlass::HostTensor block_C; +// Output Tensor +cutlass::HostTensor block_D; +// Reference Output Tensor +cutlass::HostTensor block_reference_D; +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + int swizzle = 0; + + Options(): + help(false), + m(1024), n(1024), k(1024), + alpha(1.f), beta(0.f), + iterations(10), + swizzle(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "89_sm103_fp4_ultra_gemm\n\n" + << " Sm103 3xFP4 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --swizzle= Cluster rasterization swizzle\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out << "\n\nExamples:\n\n" + << "$ " << "./examples/89_sm103_fp4_ultra_gemm/89_sm103_fp4_ultra_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + // For SFA and SFB tensors layouts + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + layout_A = make_layout(make_shape(options.m, options.k, 1), stride_A); + layout_B = make_layout(make_shape(options.n, options.k, 1), stride_B); + layout_C = make_layout(make_shape(options.m, options.n, 1), stride_C); + layout_D = make_layout(make_shape(options.m, options.n, 1), stride_D); + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(options.m, options.n, options.k, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(options.m, options.n, options.k, 1)); + + block_A.reset(cutlass::make_Coord(size(layout_A))); + block_B.reset(cutlass::make_Coord(size(layout_B))); + block_C.reset(cutlass::make_Coord(size(layout_C))); + block_D.reset(cutlass::make_Coord(size(layout_D))); + block_reference_D.reset(cutlass::make_Coord(size(layout_D))); + block_SFA.reset(cutlass::make_Coord(size(filter_zeros(layout_SFA)))); + block_SFB.reset(cutlass::make_Coord(size(filter_zeros(layout_SFB)))); + + initialize_block(block_A.host_view(), seed + 2021); + initialize_block(block_B.host_view(), seed + 2022); + initialize_block(block_C.host_view(), seed + 2023); + initialize_block(block_SFA.host_view(), seed + 2024); + initialize_block(block_SFB.host_view(), seed + 2025); + + block_A.sync_device(); + block_B.sync_device(); + block_C.sync_device(); + block_SFA.sync_device(); + block_SFB.sync_device(); +} + +// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + { // Mainloop arguments + block_A.device_data(), stride_A, + block_B.device_data(), stride_B, + block_SFA.device_data(), layout_SFA, + block_SFB.device_data(), layout_SFB + }, + { // Epilogue arguments + {options.alpha, options.beta}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + } + }; + + arguments.scheduler.max_swizzle_size = options.swizzle; + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.host_data(), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, // ElementAccumulator + decltype(tensor_A), // TensorA + decltype(tensor_SFA), // TensorSfA + decltype(tensor_B), // TensorB + decltype(tensor_SFB) // TensorSfB + > mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.host_data()), layout_C); + auto tensor_D = cute::make_tensor(make_iterator(block_reference_D.host_data()), layout_D); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementAccumulator, // ElementScalar + ElementAccumulator, // ElementAccumulator + ElementAccumulator, // ElementCompute + decltype(tensor_C), // TensorC + decltype(tensor_D) // TensorD + > epilogue_params{options.alpha, options.beta, tensor_C, tensor_D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Comparison + block_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(block_reference_D.host_view(), block_D.host_view()); + passed &= (cutlass::reference::host::TensorNorm(block_reference_D.host_view()) > 0); + passed &= (cutlass::reference::host::TensorNorm(block_D.host_view()) > 0); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + uint8_t* workspace = nullptr; + cudaError_t status = cudaMalloc(&workspace, workspace_size); + if (status != cudaSuccess) { + std::cerr << "Failed to allocate workspace memory: " << cudaGetErrorString(status) << std::endl; + return -1; + } + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace)); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Free workspace memory + cudaFree(workspace); + + cudaDeviceSynchronize(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace)); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.9 or higher Toolkit to run this example + // and must have compute capability at least 100. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer." << std::endl; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (!(props.major == 10 && props.minor == 3)) { + std::cerr << "This example requires a GPU of NVIDIA's Blackwell architecture (compute capability 103)." << std::endl; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/89_sm103_fp4_ultra_gemm/CMakeLists.txt b/examples/89_sm103_fp4_ultra_gemm/CMakeLists.txt new file mode 100644 index 00000000..d60baf04 --- /dev/null +++ b/examples/89_sm103_fp4_ultra_gemm/CMakeLists.txt @@ -0,0 +1,38 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 103a) + +cutlass_example_add_executable( + 89_sm103_fp4_ultra_gemm + 89_sm103_fp4_ultra_gemm.cu +) + +endif() diff --git a/examples/90_sm103_fp4_ultra_grouped_gemm/90_sm103_fp4_ultra_grouped_gemm.cu b/examples/90_sm103_fp4_ultra_grouped_gemm/90_sm103_fp4_ultra_grouped_gemm.cu new file mode 100644 index 00000000..18f592d2 --- /dev/null +++ b/examples/90_sm103_fp4_ultra_grouped_gemm/90_sm103_fp4_ultra_grouped_gemm.cu @@ -0,0 +1,1018 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Grouped GEMM example using CUTLASS 3 APIs for the NVIDIA Blackwell SM100 architecture. + + This example demonstrates an implementation of Grouped GEMM using a TMA + Blackwell SM103 TensorOp-based warp-specialized kernel + for narrow precisions (FP4) with Scale Factors (In and Out). + For this example all scheduling work is performed on the device. + The new feature showcased in this example is device-side modification of TMA descriptors + to move between groups/problem_count (represented by groups). + https://docs.nvidia.com/cuda/cuda-c-programming-guide/#encoding-a-tensor-map-on-device + + To run this example: + + $ ./examples/90_sm103_fp4_ultra_grouped_gemm/90_sm103_fp4_ultra_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10 + + The above example command makes all 10 groups to be sized at the given m, n, k sizes. + Skipping any of the problem dimensions randomizes it across the different groups. + Same applies for alpha and beta values that are randomized across the different groups. + + To run this example for a set of problems using the benchmark option: + + $ ./examples/90_sm103_fp4_ultra_grouped_gemm/90_sm103_fp4_ultra_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "helper.h" +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementInput = cutlass::float_e2m1_t; // Element type for Input matrix operands +using ElementSF = cutlass::float_ue4m3_t; // Element type for SF matrix operands +using ElementC = cutlass::half_t; // Element type for C matrix operands + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = ElementInput; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = ElementInput; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 32; // Alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementD = ElementC; // Element type for D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) +using ElementAccumulator = float; // Element type for internal accumulation + +// using ElementD = cutlass::float_e2m1_t; // Enable for SF Output // Element type for D matrix operands + +using ElementSFD = cutlass::float_ue4m3_t; // Element type for SF Output operands +constexpr int OutputSFVectorSize = 16; +using FusionOperation = cutlass::epilogue::fusion::LinCombEltActBlockScaleFactor< + cutlass::epilogue::thread::SiLu, + OutputSFVectorSize, + ElementD, + ElementAccumulator, + ElementSFD, + LayoutC, + ElementC>; + +// Core kernel configurations +using ArchTag = cutlass::arch::Sm103; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Runtime Cluster Shape +using ClusterShape = cute::Shape; + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = cute::Shape>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm; // Epilogue to launch +}; + +struct MMA2SMConfig { + using MmaTileShape = cute::Shape>; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm; // Epilogue to launch +}; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutA *, AlignmentA, + cute::tuple, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; +using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; +using Gemm = Gemm1SM; + +using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + typename MMA2SMConfig::MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutC *, AlignmentD, + typename MMA2SMConfig::EpilogueSchedule + // , FusionOperation // Enable for SF Output +>::CollectiveOp; +using CollectiveMainloop2SM = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutA *, AlignmentA, + cute::tuple, LayoutB *, AlignmentB, + ElementAccumulator, + typename MMA2SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue2SM::SharedStorage))>, + typename MMA2SMConfig::KernelSchedule +>::CollectiveOp; +using GemmKernel2SM = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop2SM, + CollectiveEpilogue2SM +>; +using Gemm2SM = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; +using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; +using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + OutputSFVectorSize, + cute::is_same_v ? cute::UMMA::Major::K : cute::UMMA::Major::MN + >; +using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; +using LayoutSFD = typename Sm1xxBlockScaledOutputConfig::LayoutSF; + +// Host-side allocations +std::vector stride_A_host; +std::vector stride_B_host; +std::vector layout_SFA_host; +std::vector layout_SFB_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +using HostTensorA = cutlass::HostTensor; +using HostTensorB = cutlass::HostTensor; +using HostTensorSF = cutlass::HostTensor; +using HostTensorC = cutlass::HostTensor; +using HostTensorD = cutlass::HostTensor; +std::vector block_A; +std::vector block_B; +std::vector block_SFA; +std::vector block_SFB; +std::vector block_C; +std::vector block_D; +std::vector block_SFD; +std::vector block_ref_D; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_SFA; +cutlass::DeviceAllocation ptr_SFB; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_SFD; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation layout_SFA; +cutlass::DeviceAllocation layout_SFB; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +// A matrix wide constant value to scale the output matrix +// Avoids generating small FP4 values. +// NormConst is a single device-side constant value, its not per-batch or per-group +cutlass::DeviceAllocation norm_constant_device; + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +// Command line options parsing +struct Options { + + bool help = false; + bool verification = true; + bool use_pdl = false; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + float norm_constant = 1.0; + int iterations = 3; + int m = 1024, n = 1024, k = 512, groups = 5; + dim3 cluster_shape = dim3(2,1,1); + dim3 cluster_shape_fallback = dim3(2,1,1); + RasterOrderOptions raster_order = RasterOrderOptions::AlongN; + int max_sm_count = INT_MAX; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + if (cmd.check_cmd_line_flag("no_verif")) { + verification = false; + } + if (cmd.check_cmd_line_flag("use_pdl")) { + use_pdl = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("norm_constant", norm_constant, float(1.0)); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("cluster_m", cluster_shape.x); + cmd.get_cmd_line_argument("cluster_n", cluster_shape.y); + cmd.get_cmd_line_argument("cluster_fallback_m", cluster_shape_fallback.x); + cmd.get_cmd_line_argument("cluster_fallback_n", cluster_shape_fallback.y); + cmd.get_cmd_line_argument("max_sm_count", max_sm_count, INT_MAX); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 32) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 32) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 32) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "90_sm103_fp4_ultra_grouped_gemm\n\n" + << " Sm103 3xFP4 Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --norm_constant= Epilogue scalar normalization constant for the output matrix\n\n" + << " --cluster_m= and --cluster_n= Sets the X,Y dims of the preferred cluster shape\n" + << " --cluster_fallback_m= and --cluster_fallback_n= Sets the X,Y dims of the fallback cluster shape\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M)\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size\n" + << " --max_sm_count= Run kernels using only these number of SMs\n" + << " --no_verif Do not run (host-side) verification kernels\n" + << " --use_pdl Launch kernel with PDL (Programmatic Dependent Launch) enabled\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "90_sm103_fp4_ultra_grouped_gemm" << " --m=1024 --n=1024 --k=512 --groups=5 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + constexpr int bits_input = cutlass::sizeof_bits::value; + + if constexpr (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if constexpr (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + else if constexpr (bits_input <= 8) { + if constexpr (cute::is_same_v) { + scope_max = 4; + scope_min = 1; + } + else { + scope_max = 1; + scope_min = -1; + } + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + std::cout << "===> Starting memory allocation..." << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + std::cout << " Allocating problem " << i << ": M=" << M << ", N=" << N << ", K=" << K << std::endl; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + stride_A_host.push_back(stride_A); + stride_B_host.push_back(stride_B); + layout_SFA_host.push_back(layout_SFA); + layout_SFB_host.push_back(layout_SFB); + stride_C_host.push_back(stride_C); + stride_D_host.push_back(stride_D); + + block_A.push_back(HostTensorA(cutlass::make_Coord(size(layout_A)))); + block_B.push_back(HostTensorB(cutlass::make_Coord(size(layout_B)))); + block_SFA.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFA))))); + block_SFB.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFB))))); + block_C.push_back(HostTensorC(cutlass::make_Coord(size(layout_C)))); + block_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + block_SFD.push_back(HostTensorSF(cutlass::make_Coord(size(filter_zeros(layout_SFD))))); + block_ref_D.push_back(HostTensorD(cutlass::make_Coord(size(layout_D)))); + } + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + std::cout << "===> Memory allocation completed for " << options.groups << " groups" << std::endl; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + std::cout << "===> Starting initialization..." << std::endl; + uint64_t seed = 2020; + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + std::cout << " Problem sizes copied to device" << std::endl; + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_SFA_host(options.groups); + std::vector ptr_SFB_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_SFD_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + std::cout << " Initializing tensor data for " << options.groups << " groups..." << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " Initializing group " << i << "/" << options.groups << std::endl; + + initialize_block(block_A.at(i).host_view(), seed + 2021); + initialize_block(block_B.at(i).host_view(), seed + 2022); + initialize_block(block_C.at(i).host_view(), seed + 2023); + initialize_block(block_SFA.at(i).host_view(), seed + 2024); + initialize_block(block_SFB.at(i).host_view(), seed + 2025); + + block_A.at(i).sync_device(); + block_B.at(i).sync_device(); + block_C.at(i).sync_device(); + block_SFA.at(i).sync_device(); + block_SFB.at(i).sync_device(); + + ptr_A_host.at(i) = block_A.at(i).device_data(); + ptr_B_host.at(i) = block_B.at(i).device_data(); + ptr_SFA_host.at(i) = block_SFA.at(i).device_data(); + ptr_SFB_host.at(i) = block_SFB.at(i).device_data(); + ptr_C_host.at(i) = block_C.at(i).device_data(); + ptr_D_host.at(i) = block_D.at(i).device_data(); + ptr_SFD_host.at(i) = block_SFD.at(i).device_data(); + + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + std::cout << " Transferring pointers and parameters to device..." << std::endl; + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_SFA.reset(options.groups); + ptr_SFA.copy_from_host(ptr_SFA_host.data()); + + ptr_SFB.reset(options.groups); + ptr_SFB.copy_from_host(ptr_SFB_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_SFD.reset(options.groups); + ptr_SFD.copy_from_host(ptr_SFD_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + layout_SFA.reset(options.groups); + layout_SFA.copy_from_host(layout_SFA_host.data()); + + layout_SFB.reset(options.groups); + layout_SFB.copy_from_host(layout_SFB_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + + norm_constant_device.reset(1); + norm_constant_device.copy_from_host(&options.norm_constant); + + std::cout << "===> Initialization completed" << std::endl; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename Gemm::Arguments args_from_options(Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = min(cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id), options.max_sm_count); + + if (!is_static_v) { + if (size<0>(typename Gemm::GemmKernel::CollectiveMainloop::AtomThrShapeMNK{}) == 2 && + (options.cluster_shape.x < 2 || options.cluster_shape_fallback.x < 2)) { + std::cout << "Error: MMA2SMConfig kernel config needs cluster_dim.x >= 2" << std::endl; + } + hw_info.cluster_shape = options.cluster_shape; + hw_info.cluster_shape_fallback = options.cluster_shape_fallback; + } + + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + // If alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + if (options.alpha != FLT_MAX){ + // Single alpha for all groups + fusion_args.alpha = options.alpha; + fusion_args.alpha_ptr_array = nullptr; + } + else { + fusion_args.alpha = 0; + fusion_args.alpha_ptr_array = alpha_device.get(); + // Only one alpha per each group + } + if (options.beta != FLT_MAX) { + // Single beta for all groups + fusion_args.beta = options.beta; + fusion_args.beta_ptr_array = nullptr; + } + else { + fusion_args.beta = 0; + fusion_args.beta_ptr_array = beta_device.get(); + // Only one beta per each group + } + // Output Block SF + // fusion_args.block_scale_factor_ptr = ptr_SFD.get(); // Enable for SF Output + // fusion_args.norm_constant_ptr = norm_constant_device.get(); // Enable for SF Output + + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = options.raster_order; + + if (host_problem_shapes_available) { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + else { + arguments = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_SFA.get(), layout_SFA.get(), ptr_SFB.get(), layout_SFB.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, scheduler + }; + } + + return arguments; +} + +bool verify(const Options &options) { + using namespace cute; + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); + auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1}); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); + auto layout_A = make_layout(make_shape(M, K, 1), stride_A); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B); + auto layout_C = make_layout(make_shape(M, N, 1), stride_C); + auto layout_D = make_layout(make_shape(M, N, 1), stride_D); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + auto layout_SFD = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(cute::make_shape(M, N, K, 1)); + + // Create the arguments for host reference implementation + Tensor tensor_A = make_tensor(make_iterator(block_A.at(i).host_data()), layout_A); + Tensor tensor_SFA = make_tensor(block_SFA.at(i).host_data(), layout_SFA); + Tensor tensor_B = make_tensor(make_iterator(block_B.at(i).host_data()), layout_B); + Tensor tensor_SFB = make_tensor(block_SFB.at(i).host_data(), layout_SFB); + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{tensor_A, tensor_SFA, tensor_B, tensor_SFB}; + + auto tensor_C = cute::make_tensor(make_iterator(block_C.at(i).host_data()), layout_C); + auto tensor_ref_D = cute::make_tensor(make_iterator(block_ref_D.at(i).host_data()), layout_D); + + cutlass::reference::host::GettEpilogueParams< + float, float, + ElementAccumulator, ElementAccumulator, + decltype(tensor_C), decltype(tensor_ref_D) + > epilogue_params{}; + + epilogue_params.C = tensor_C; + epilogue_params.D = tensor_ref_D; + epilogue_params.alpha = alpha_host.at(i); + epilogue_params.beta = beta_host.at(i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + block_D.at(i).sync_host(); + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::host::TensorEquals(block_ref_D.at(i).host_view(), block_D.at(i).host_view()); + + if (!passed) { + std::cout << "--- Verification failed at group " << i << " ---\n"; +#if 0 + std::ostringstream ref_out, comp_out; + cutlass::TensorViewWrite(ref_out, block_ref_D.at(i).host_view()); + std::cout << "\n[Reference Tensor D]\n" << ref_out.str() << "\n"; + + cutlass::TensorViewWrite(comp_out, block_D.at(i).host_view()); + std::cout << "\n[Computed Tensor D]\n" << comp_out.str() << "\n"; +#endif + } + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + std::cout << "\n===> Starting GEMM execution with " << options.groups << " groups" << std::endl; + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + std::cout << " Creating kernel arguments..." << std::endl; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + std::cout << " Workspace size required: " << workspace_size << " bytes" << std::endl; + + // Allocate workspace memory + std::cout << " Allocating workspace memory..." << std::endl; + uint8_t* workspace = nullptr; + cudaError_t status = cudaMalloc(&workspace, workspace_size); + if (status != cudaSuccess) { + std::cerr << "ERROR: Failed to allocate workspace memory: " << cudaGetErrorString(status) << std::endl; + return -1; + } + std::cout << " Workspace memory allocated successfully" << std::endl; + + // Check if the problem size is supported or not + std::cout << " Checking if problem size is supported..." << std::endl; + auto can_impl = gemm.can_implement(arguments); + if (can_impl == cutlass::Status::kSuccess) { + std::cout << " Problem size is supported" << std::endl; + } else { + std::cout << "ERROR: Problem size is not supported: " << cutlassGetStatusString(can_impl) << std::endl; + } + CUTLASS_CHECK(can_impl); + + // Initialize CUTLASS kernel with arguments and workspace pointer + std::cout << " Initializing CUTLASS kernel..." << std::endl; + CUTLASS_CHECK(gemm.initialize(arguments, workspace)); + std::cout << " Kernel initialized successfully" << std::endl; + + // Correctness / Warmup iteration + std::cout << " Running warmup iteration..." << std::endl; + auto run_status = gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl); + if (run_status == cutlass::Status::kSuccess) { + std::cout << " Warmup iteration completed successfully" << std::endl; + } else { + std::cout << "ERROR: Warmup iteration failed: " << cutlassGetStatusString(run_status) << std::endl; + } + CUTLASS_CHECK(run_status); + + // Free workspace memory + std::cout << " Freeing workspace memory..." << std::endl; + CUDA_CHECK(cudaFree(workspace)); + + std::cout << " Synchronizing device..." << std::endl; + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + if (options.verification) { + std::cout << "\n===> Running host-side verification" << std::endl; + std::cout << " This may be very slow for large cases." << std::endl; + result.passed = verify(options); + std::cout << " Verification result: " << (result.passed ? "PASSED" : "FAILED") << std::endl; + if (!result.passed) { + std::cout << "ERROR: Verification failed, exiting" << std::endl; + exit(-1); + } + } + else { + std::cout << " Verification is turned off for this run." << std::endl; + } + + + // Run profiling loop + if (options.iterations > 0) + { + std::cout << "\n===> Running performance measurements" << std::endl; + std::cout << " Iterations: " << options.iterations << std::endl; + std::cout << " Allocating workspace for profiling..." << std::endl; + + // Re-allocate workspace for profiling + status = cudaMalloc(&workspace, workspace_size); + if (status != cudaSuccess) { + std::cerr << "ERROR: Failed to allocate workspace memory for profiling: " << cudaGetErrorString(status) << std::endl; + return -1; + } + + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace)); + CUTLASS_CHECK(gemm.run(/* stream = */ nullptr, /* cuda_adapter = */ nullptr, /* launch_with_pdl = */ options.use_pdl)); + } + timer.stop(); + + // Free profiling workspace + cudaFree(workspace); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << "\n===> Performance results:" << std::endl; + std::cout << " Total time : " << elapsed_ms << " ms" << std::endl; + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + std::cout << "\n====================================================" << std::endl; + std::cout << "CUTLASS 3.0 Grouped GEMM Example - 3xfp4 Block Scaled" << std::endl; + std::cout << "====================================================" << std::endl; + + // CUTLASS must be compiled with CUDA 12.9 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer.\n"; + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + + if (!(props.major == 10 && props.minor == 3)) { + std::cerr << "ERROR: This example requires a GPU of NVIDIA's Blackwell Architecture (compute capability 103a).\n"; + return 0; + } + + // + // Parse options + // + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + allocate(options); + initialize(options); + + // + // Evaluate CUTLASS kernels + // + + std::cout << "\n===> Running kernel with 1SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); + std::cout << "Running kernel with 2SM MMA config:" << std::endl; + run(options, false /*host_problem_shapes_available*/); +#else + std::cout << "\nERROR: CUTLASS_ARCH_MMA_SM103_SUPPORTED is not defined. This example cannot run on this system." << std::endl; +#endif + + std::cout << "\n===> Example completed successfully" << std::endl; + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt b/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt new file mode 100644 index 00000000..2c6bd869 --- /dev/null +++ b/examples/90_sm103_fp4_ultra_grouped_gemm/CMakeLists.txt @@ -0,0 +1,42 @@ +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_RANDOM_SMALL_GROUP --groups=3 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_SMALL_GROUP --alpha=1.5 --beta=2.0 --groups=3 --iterations=1) # Random problem sizes + +if (CUTLASS_NVCC_ARCHS MATCHES 103a) + +cutlass_example_add_executable( + 90_sm103_fp4_ultra_grouped_gemm + 90_sm103_fp4_ultra_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM_SMALL_GROUP + TEST_EPILOGUE_SMALL_GROUP +) + +endif() diff --git a/examples/91_fp4_gemv/91_fp4_gemv.cu b/examples/91_fp4_gemv/91_fp4_gemv.cu new file mode 100644 index 00000000..65fb2a0f --- /dev/null +++ b/examples/91_fp4_gemv/91_fp4_gemv.cu @@ -0,0 +1,898 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include // uint64_t +#include +#include // rand(), RAND_MAX +#include // std::stoi +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +// clang-format off +#include "cute/tensor.hpp" // FIX cute header file inclusion issue +// clang-format on + +#include "cute/arch/mma_sm100_desc.hpp" // cute::UMMA::Major +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cutlass/complex.h" // cutlass::ComplexTransform +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/sm100_blockscaled_layout.hpp" // cutlass::detail::Sm1xxBlockScaledOutputConfig +#include "cutlass/epilogue/thread/linear_combination.h" // cutlass::epilogue::thread::LinearCombination +#include "cutlass/gemm/device/gemv_blockscaled.h" // cutlass::gemm::device::Gemv +#include "cutlass/gemm/kernel/gemv_blockscaled.h" // cutlass::gemm::kernel::Gemv +#include "cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h" // cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor +#include "cutlass/gemm_coord.h" // cutlass::GemmCoord +#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory +#include "cutlass/numeric_size.h" // cutlss::is_subbyte +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" // cutlass::is_same_v +#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation +#include "cutlass/util/distribution.h" // cutlass::Distribution +#include "cutlass/util/host_tensor.h" // cutlass::HostTensor +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/util/reference/host/gemm_complex.h" // cutlass::reference::host::GemmComplex +#include // cutlass::reference::host::GettBlockScalingMainloopParams +// cutlass::reference::host::GettBlockScalingEpilogueParams +// cutlass::reference::host::Gemm3x +#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals +#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes + +// Helper Functions +template +auto +make_iterator(T* ptr) +{ + return cute::recast_ptr(ptr); +} + +template +bool +initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) +{ + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } else if (bits_input <= 8) { + if constexpr (cutlass::is_same_v || + cutlass::is_same_v) { + scope_max = 4; + scope_min = 1; + } else { + scope_max = 1; + scope_min = -1; + } + } else { + scope_max = 4; + scope_min = -4; + } + + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + + else { + CUTLASS_ASSERT(false); + return false; + } + + return true; +} + +// Base class of Testbed +template < + typename Gemv_, + // The following types are more difficult to be derived from EVT + typename ElementC, typename LayoutC, typename ElementD_, + typename LayoutD, typename ElementSFD_, typename LayoutSFD, + typename ElementCompute_, int kVectorSize_> +struct TestbedGemvFp4SFDBase +{ + public: + using Gemv = Gemv_; + + using ElementA = typename Gemv::ElementA; + using ElementSFA = typename Gemv::ElementSFA; + using LayoutA = typename Gemv::LayoutA; + static_assert(cutlass::is_same_v, "only support row major matrix A"); + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFA should be FP8 type"); + + using ElementB = typename Gemv::ElementB; + using ElementSFB = typename Gemv::ElementSFB; + using LayoutB = cutlass::layout::ColumnMajor; + static_assert(cutlass::is_same_v, "only support ElementA ElementB of same type"); + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFB should be FP8 type"); + + static_assert(cutlass::is_same_v, "only support col major output D"); + + using ElementD = ElementD_; + static_assert(cutlass::is_same_v, "only support col major output D"); + + using ElementSFD = ElementSFD_; + static_assert(cutlass::is_same_v, "only support col major output SFD"); + static_assert(cutlass::sizeof_bits::value, "only support 8 bit SFD"); + + using ElementAccumulator = typename Gemv::ElementAccumulator; + using ElementCompute = ElementCompute_; + static_assert(cutlass::is_same_v, "only support fp32 epi compute"); + + static constexpr int kVectorSize = kVectorSize_; + static_assert(kVectorSize == 16, "only support vs 16"); + + // SFD Config + static constexpr bool kIsKMajorSFD = cutlass::is_same_v; + using Sm1xxBlockScaledOutputConfig= + cutlass::detail::Sm1xxBlockScaledOutputConfig; + using Blk_MN_Output = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF_Output = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + + // SFA SFB Config + using Sm100BlockScaledInputConfig = cutlass::detail::Sm1xxBlockScaledConfig; + using Blk_MN_Input = typename Sm100BlockScaledInputConfig::Blk_MN; + using Blk_SF_Input = typename Sm100BlockScaledInputConfig::Blk_SF; + using SfAtom_Input = typename Sm100BlockScaledInputConfig::SfAtom; + + public: + TestbedGemvFp4SFDBase(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_SFA_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_SFB_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_SFD_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2023) + : init_A(init_A_) + , init_B(init_B_) + , init_C(init_C_) + , init_D(init_D_) + , init_SFA(init_SFA_) + , init_SFB(init_SFB_) + , init_SFD(init_SFD_) + , seed(seed_) + { + } + + bool initialize(cutlass::MatrixCoord problem_size, int32_t batch_count) + { + const int32_t gemm_m = problem_size.row(); + const int32_t gemm_k = problem_size.column(); + const int32_t gemm_n = 1; + const int32_t gemm_batch = batch_count; + + // Resize Config SFA/SFB + auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{}))); + auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{}); + auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{}); + + auto sfa_coord = cutlass::make_Coord(m_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{}); + auto sfb_coord = cutlass::make_Coord(n_blks_input * Blk_MN_Input{} * gemm_batch, k_blks_input * Blk_SF_Input{}); + + auto sfa_resize_layout = + cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, typename LayoutA::Stride{}); + auto sfb_resize_layout = + cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, typename LayoutB::Stride{}); + + // Use the same SFD layout generation as reference for tensor creation + using ProblemShapeType = cute::Shape; + auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch}; + + // Generate the same layout as reference uses + auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL); + + // Extract size from the generated layout and create coordinate + auto sfd_size = cute::size(cute::filter_zeros(sfd_layout)); + auto sfd_coord = cutlass::make_Coord(sfd_size, 1); // Linear layout for HostTensor + + auto sfd_resize_layout = + cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, typename LayoutSFD::Stride{}); + + // Resize Host + this->reference_D.resize({gemm_batch * gemm_m, 1}); // D col major vector + this->reference_SFD.resize(sfd_coord, sfd_resize_layout); + + if (initialize_tensor(this->reference_D.host_view(), this->init_D, this->seed + 7) == false) { + printf("initialize_tensor() REF D failed\n"); + return false; + } + if (initialize_tensor(this->reference_SFD.host_view(), this->init_SFD, this->seed + 9) == false) { + printf("initialize_tensor() REF SFD failed\n"); + return false; + } + + // Resize A/B/C/D + this->tensor_A.resize({gemm_batch * gemm_m, gemm_k}); // A row major + this->tensor_B.resize({gemm_batch * gemm_k, 1}); // B col major vector + this->tensor_C.resize({gemm_batch * gemm_m, 1}); // C col major vector + this->tensor_D.resize({gemm_batch * gemm_m, 1}); // D col major vector + this->tensor_SFA.resize(sfa_coord, sfa_resize_layout); + this->tensor_SFB.resize(sfb_coord, sfb_resize_layout); + this->tensor_SFD.resize(sfd_coord, sfd_resize_layout); + + // Fill A/B/C + if (initialize_tensor(this->tensor_A.host_view(), this->init_A, this->seed + 1) == false) { + printf("initialize_tensor() A failed\n"); + return false; + } + if (initialize_tensor(this->tensor_B.host_view(), this->init_B, this->seed + 2) == false) { + printf("initialize_tensor() B failed\n"); + return false; + } + if (initialize_tensor(this->tensor_C.host_view(), this->init_C, this->seed + 3) == false) { + printf("initialize_tensor() C failed\n"); + return false; + } + + // Fill SFA/SFB + if (initialize_tensor(this->tensor_SFA.host_view(), this->init_SFA, this->seed + 4) == false) { + printf("initialize_tensor() SFA failed\n"); + return false; + } + if (initialize_tensor(this->tensor_SFB.host_view(), this->init_SFB, this->seed + 5) == false) { + printf("initialize_tensor() SFB failed\n"); + return false; + } + + // Fill D/SFD + if (initialize_tensor(this->tensor_D.host_view(), this->init_D, this->seed + 6) == false) { + printf("initialize_tensor() D failed\n"); + return false; + } + if (initialize_tensor(this->tensor_SFD.host_view(), this->init_SFD, this->seed + 8) == false) { + printf("initialize_tensor() SFD failed\n"); + return false; + } + + // Copy A/B/C from host to device + this->tensor_A.sync_device(); + this->tensor_B.sync_device(); + this->tensor_C.sync_device(); + this->tensor_D.sync_device(); + this->tensor_SFA.sync_device(); + this->tensor_SFB.sync_device(); + this->tensor_SFD.sync_device(); + + // SFD initialization is different. + // Init referenceSFD on host first, and then copy data to tensorSFD device side. + // This ensures tensorSFD and referenceSFD to have same data, + // otherwise the "bubbles" due to SFD layouts can lead to false negative sanity check. + cutlass::device_memory::copy_to_host(this->reference_SFD.host_data(), this->tensor_SFD.device_data(), sfd_size); + + return true; + } + + bool compare_reference() + { + // device -> host + this->tensor_D.sync_host(); + + bool passed = true; + + // Check + passed = cutlass::reference::host::TensorEquals(this->reference_D.host_view(), this->tensor_D.host_view()); + if (passed == false) { + printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1)); + printf("tensorD mismatch\n"); + return false; + } + + this->tensor_SFD.sync_host(); + + passed = cutlass::reference::host::TensorEquals(this->reference_SFD.host_view(), this->tensor_SFD.host_view()); + if (passed == false) { + printf("gemm_m: %d, gemm_k: %d, ", this->tensor_A.host_view().extent(0), this->tensor_A.host_view().extent(1)); + printf("tensorSFD mismatch\n"); + return false; + } + + return passed; + } + + bool run_reference(cutlass::MatrixCoord problem_size, + int32_t batch_count, + ElementCompute alpha, + ElementCompute beta, + float epilogue_st) + { + const int32_t gemm_m = problem_size.row(); + const int32_t gemm_k = problem_size.column(); + const int32_t gemm_n = 1; + const int32_t gemm_batch = batch_count; + + // Run reference blockscale GETT + using ProblemShapeType = cute::Shape; + auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch}; + auto SfD = make_tensor(make_iterator(this->reference_SFD.host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(gemm_m, gemm_k, gemm_batch)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(gemm_n, gemm_k, gemm_batch)); + StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(gemm_m, gemm_n, gemm_batch)); + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(gemm_m, gemm_n, gemm_batch)); + + auto A = make_tensor(make_iterator(this->tensor_A.host_data()), + cute::make_layout(cute::make_shape(gemm_m, gemm_k, gemm_batch), stride_a)); + auto B = make_tensor(make_iterator(this->tensor_B.host_data()), + cute::make_layout(cute::make_shape(gemm_n, gemm_k, gemm_batch), stride_b)); + + auto C = cute::make_tensor(make_iterator(this->tensor_C.host_data()), + cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_c)); + auto D = cute::make_tensor(make_iterator(this->reference_D.host_data()), + cute::make_layout(cute::make_shape(gemm_m, gemm_n, gemm_batch), stride_d)); + + auto layout_sfa = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + auto layout_sfb = Sm100BlockScaledInputConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + auto SfA = make_tensor(this->tensor_SFA.host_data(), layout_sfa); + auto SfB = make_tensor(this->tensor_SFB.host_data(), layout_sfb); + + // Internally scale factor of mainloop will be disabled when ElementA/B == ElementSFA/B. + typename cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + typename cutlass::reference::host::GettBlockScalingEpilogueParams, // OutputVectorSize + cutlass::reference::host::SfStrategy::SfDGen + > + epilogue_params{alpha, beta, C, D, SfD, epilogue_st}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return true; + } + + virtual typename Gemv::Arguments get_arguments( + cutlass::MatrixCoord problem_size, int32_t batch_count, + float epilogue_st, ElementCompute alpha, ElementCompute beta) = 0; + + bool run_gemv(cutlass::MatrixCoord problem_size, + int32_t batch_count, + ElementCompute alpha, + ElementCompute beta, + [[maybe_unused]] float epilogue_st, + bool is_profiling, + int kIterations) + { + + // Not support batch input for testing + const int32_t gemm_m = problem_size.row(); + const int32_t gemm_k = problem_size.column(); + [[maybe_unused]] const int32_t gemm_n = 1; + [[maybe_unused]] const int32_t gemm_batch = batch_count; + + Gemv gemv_op; + typename Gemv::Arguments arguments = this->get_arguments( + problem_size, batch_count, epilogue_st, alpha, beta + ); + + cutlass::Status status = gemv_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + printf("can_implement() failed\n"); + return false; + } + + size_t workspace_size = Gemv::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemv_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + printf("initialize() failed\n"); + return false; + } + + if (not is_profiling) { + status = gemv_op(); + } + // profiling + else { + cudaError_t result; + cudaEvent_t events[2]; + + for (cudaEvent_t &evt : events) { + result = cudaEventCreate(&evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + // warmup + status = gemv_op(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Device execution failed on warmup." << std::endl; + return false; + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (int iter_i = 0; iter_i < kIterations; ++iter_i) { + status = gemv_op(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Device execution failed." << std::endl; + return false; + } + } + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + float elapsed_ms = 0; + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (cudaEvent_t &evt : events) { + result = cudaEventDestroy(evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + int64_t flops = int64_t(gemm_m) * gemm_n * gemm_k * 2; + int64_t bytes = cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_m) * int64_t(gemm_k)) + + cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_k) * int64_t(gemm_n)) + + cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_m) * int64_t(gemm_n)) + + cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_m) * int64_t(gemm_k) / int64_t(kVectorSize)) + + cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_k) * int64_t(gemm_n) / int64_t(kVectorSize)) + + cutlass::bits_to_bytes(int64_t(cute::sizeof_bits_v) * int64_t(gemm_m) * int64_t(gemm_n) / int64_t(kVectorSize)); + + double gflops_per_second = double(flops) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1.0e9); + double gbytes_per_second = double(bytes) * kIterations * gemm_batch / double(elapsed_ms / 1000.0f) / double(1 << 30); + double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; + + std::cout << " Problem: " + << gemm_m << "-by-" << gemm_n << "-by-" << gemm_k + << ", batch size: " << gemm_batch + << std::endl; + std::cout << " Runtime: " << elapsed_ms_per_iter << " ms" << std::endl; + std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; + std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; + + } + + if (status != cutlass::Status::kSuccess) { + printf("gemv exec failed\n"); + return false; + } + + return true; + } + + bool run_and_verify(cutlass::MatrixCoord problem_size, + int32_t batch_count, + ElementCompute alpha, + ElementCompute beta, + float epilogue_st) + { + + // Initialize Data + if (this->initialize(problem_size, batch_count) == false) { + return false; + } + + // Run GEMV kernel + if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, false /*is_profiling*/, 1) == false) { + return false; + } + + // Run Reference Kernel + if (this->run_reference(problem_size, batch_count, alpha, beta, epilogue_st) == false) { + printf("run_reference() failed\n"); + return false; + } + + // Verify + if (this->compare_reference() == false) { + printf("compare_reference() failed\n"); + return false; + } + + return true; + } + + bool profile(cutlass::MatrixCoord problem_size, + int32_t batch_count, + ElementCompute alpha, + ElementCompute beta, + float epilogue_st, + int kIterations = 10) + { + // Initialize Data + if (this->initialize(problem_size, batch_count) == false) { + return false; + } + + // Profile GEMV kernel + if (this->run_gemv(problem_size, batch_count, alpha, beta, epilogue_st, true /*is_profiling*/, kIterations) == false) { + return false; + } + + return true; + } + + public: + // Data Storage + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_SFA; + + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFB; + + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_SFD; + + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_SFD; + + // Data Init Setting + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_D; + cutlass::Distribution::Kind init_SFA; + cutlass::Distribution::Kind init_SFB; + cutlass::Distribution::Kind init_SFD; + uint64_t seed; +}; + +template +struct TestbedGemvFp4SFD : public TestbedGemvFp4SFDBase< + Gemv_, + typename Gemv_::ElementC, + typename Gemv_::EpilogueOutputOp::LayoutOutput, + typename Gemv_::EpilogueOutputOp::ElementD, + typename Gemv_::EpilogueOutputOp::LayoutOutput, + typename Gemv_::EpilogueOutputOp::ElementSFD, + typename Gemv_::EpilogueOutputOp::LayoutSFD, + typename Gemv_::EpilogueOutputOp::ElementCompute, + Gemv_::EpilogueOutputOp::kVectorSize +> { + using Base = TestbedGemvFp4SFDBase< + Gemv_, + typename Gemv_::ElementC, + typename Gemv_::EpilogueOutputOp::LayoutOutput, + typename Gemv_::EpilogueOutputOp::ElementD, + typename Gemv_::EpilogueOutputOp::LayoutOutput, + typename Gemv_::EpilogueOutputOp::ElementSFD, + typename Gemv_::EpilogueOutputOp::LayoutSFD, + typename Gemv_::EpilogueOutputOp::ElementCompute, + Gemv_::EpilogueOutputOp::kVectorSize + >; + + using Base::Base; + using Gemv = Gemv_; + using ElementCompute = typename Base::ElementCompute; + using SfAtom_Input = typename Base::SfAtom_Input; + using Blk_MN_Input = typename Base::Blk_MN_Input; + using Blk_SF_Input = typename Base::Blk_SF_Input; + + static constexpr int kVectorSize = Base::kVectorSize; + + typename Gemv::Arguments get_arguments( + cutlass::MatrixCoord problem_size, + int32_t batch_count, float epilogue_st, + ElementCompute alpha, ElementCompute beta) override { + + const int32_t gemm_m = problem_size.row(); + const int32_t gemm_k = problem_size.column(); + [[maybe_unused]] const int32_t gemm_n = 1; + [[maybe_unused]] const int32_t gemm_batch = batch_count; + + auto k_blks_input = cutlass::ceil_div(gemm_k, cute::size<1>(shape(SfAtom_Input{}))); + auto m_blks_input = cutlass::ceil_div(gemm_m, Blk_MN_Input{}); + auto n_blks_input = cutlass::ceil_div(gemm_n, Blk_MN_Input{}); + + int batch_stride_SFA = m_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{}; + int batch_stride_SFB = n_blks_input * Blk_MN_Input{} * k_blks_input * Blk_SF_Input{}; + + // Use the same SFD layout generation as reference to get correct batch stride + using ProblemShapeType = cute::Shape; + auto problem_shape_MNKL = ProblemShapeType{gemm_m, gemm_n, gemm_k, gemm_batch}; + + // Generate the same layout as reference uses + using Sm1xxBlockScaledOutputConfig = typename Base::Sm1xxBlockScaledOutputConfig; + auto sfd_layout = Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL); + + // Calculate batch stride from the generated layout + // Extract the batch stride from the 3rd dimension stride + // The stride<2> gives us the stride for the batch dimension + auto batch_stride_tuple = cute::stride<2>(sfd_layout); // This returns (_0, 8192) + int batch_stride_SFD = static_cast(cute::get<1>(batch_stride_tuple)); // Extract the 8192 part + + // Initialize GEMV kernel + typename Gemv::Arguments arguments{ + problem_size, // problem_size + batch_count, // batch_count + typename Gemv::EpilogueOutputOp::Params{ + this->tensor_D.device_ref(), // tensor_d + this->tensor_SFD.device_data(), // scale_factor_d_ptr + alpha, // alpha + beta, // beta + epilogue_st, // st + batch_stride_SFD, // batch_stride_sfd + gemm_m // stride_d + }, + this->tensor_A.device_ref(), // ref_A + this->tensor_B.device_data(), // ptr_B + this->tensor_C.device_data(), // ptr_C + this->tensor_D.device_data(), // ptr_D + this->tensor_SFA.device_data(), // ptr_SFA + this->tensor_SFB.device_data(), // ptr_SFB + gemm_k, // stride_A + gemm_m * gemm_k, // batch_stride_A + gemm_k, // batch_stride_B + gemm_m, // batch_stride_C + gemm_m, // batch_stride_D + batch_stride_SFA, // batch_stride_SFA + batch_stride_SFB, // batch_stride_SFB + batch_stride_SFD // batch_stride_SFD + }; + + return arguments; + } +}; + +struct Options { + bool help = false; + + int m = 4096; + int k = 2048; + int n = 1; + int batch = 1; + + float alpha = 1.0f; + float beta = 0.0f; + float epilogue_st = -1.0f; // sentinel for random + + bool profiling = true; + int iterations = 10; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("batch", batch); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("epilogue_st", epilogue_st); + cmd.get_cmd_line_argument("profiling", profiling); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "91_fp4_gemv\n\n" + << " FP4 GEMV with block-scaled inputs and outputs.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --batch= Sets the batch count of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --epilogue_st= Epilogue ST value\n\n" + << " --profiling= Whether to run profiling\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "91_fp4_gemv" << " --m=4096 --k=2048 --batch=1 \n\n"; + + return out; + } +}; + +bool +run_fp4_gemv_device(Options const& options) +{ + CUTLASS_ASSERT(options.n == 1); + + using ElementA = cutlass::float_e2m1_t; + using ElementSFA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + + using ElementB = cutlass::float_e2m1_t; + using ElementSFB = cutlass::float_e4m3_t; + + using ElementC = cutlass::float_e2m1_t; + + using ElementD = cutlass::float_e2m1_t; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementSFD = cutlass::float_e4m3_t; + // Indicate SF is computed along col dim. Does NOT indicate actual layout of SFD + using LayoutSFD = cutlass::layout::ColumnMajor; + + using ElementAccumulatorMainloop = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + + ElementCompute alpha{options.alpha}; + ElementCompute beta{options.beta}; + // Must be a positive number. + const float epilogue_st = options.epilogue_st < 0.f ? + static_cast(rand()) / (static_cast(RAND_MAX / 5)) : + options.epilogue_st; + + static constexpr int kVectorSize = 16; + static constexpr int kElementsPerAccess = 128 / cutlass::sizeof_bits::value; + + using ThreadShape = cutlass::gemm::GemmShape<16, 8>; + static_assert(kVectorSize == ThreadShape::kM, "vector size and thread in row should be equal"); + + // Construct Epilogue + using EpilogueOp = typename cutlass::epilogue::threadblock::GemvEpilogueWithScalingFactor; + + // Construct Mainloop + using Gemv = cutlass::gemm::device::GemvBlockScaled< + cutlass::gemm::kernel:: + GemvBlockScaled>; + + TestbedGemvFp4SFD testbed; + + bool pass = true; + + if (options.profiling) { + pass = testbed.profile(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st, options.iterations); + } + else { + pass = testbed.run_and_verify(cutlass::MatrixCoord{options.m, options.k}, options.batch, alpha, beta, epilogue_st); + } + + return pass; +} + +int +main(int argc, char const** argv) +{ +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Run verification + Options verification_options = options; + verification_options.profiling = false; + + bool passed = run_fp4_gemv_device(verification_options); + if (passed == false) { + printf("test fail\n"); + return 1; + } else { + printf("test pass\n"); + } + + + if (options.profiling) { + // Start profiling + printf("\nProfiling...\n"); + passed = run_fp4_gemv_device(options); + if (passed == false) { + printf("profiling fail\n"); + return 1; + } else { + printf("profiling completed\n"); + } + + } + + return 0; +#else + std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM100_SUPPORTED is defined.\n"; + return 0; +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +} diff --git a/examples/91_fp4_gemv/CMakeLists.txt b/examples/91_fp4_gemv/CMakeLists.txt new file mode 100644 index 00000000..c7dd884b --- /dev/null +++ b/examples/91_fp4_gemv/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if (NOT MSVC) + +cutlass_example_add_executable( + 91_fp4_gemv + 91_fp4_gemv.cu +) + +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 978105ea..b46fbbda 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -163,7 +163,12 @@ foreach(EXAMPLE 82_blackwell_distributed_gemm 83_blackwell_sparse_gemm 84_blackwell_narrow_precision_sparse_gemm + 86_blackwell_mixed_dtype_gemm + 87_blackwell_geforce_gemm_blockwise 88_hopper_fmha + 89_sm103_fp4_ultra_gemm + 90_sm103_fp4_ultra_grouped_gemm + 91_fp4_gemv ) add_subdirectory(${EXAMPLE}) diff --git a/examples/README.md b/examples/README.md index 3f79df9a..4765125f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,5 +1,14 @@ # CUTLASS - Programming Examples +> [!IMPORTANT] +> ### ⚠️ **Not for Benchmarking!** ⚠️ +> +> These examples are designed **solely for demonstrating CUTLASS functionality** and may **NOT optimized for performance benchmarking**. +> +> **For accurate performance measurements**, please use the **[CUTLASS Profiler](../tools/profiler/)** instead (recommended) or manually auto-tune the example, if unavailable via the profiler. +> + + * [00_basic_gemm](00_basic_gemm/) launches a basic GEMM with single precision inputs and outputs diff --git a/examples/common/dist_gemm_helpers.h b/examples/common/dist_gemm_helpers.h index ef258e69..35f05442 100644 --- a/examples/common/dist_gemm_helpers.h +++ b/examples/common/dist_gemm_helpers.h @@ -39,14 +39,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #include #include -#include +#include CUDA_STD_HEADER(atomic) #include "cute/layout.hpp" #include "cute/tensor.hpp" -#include "cutlass/cutlass.h" #include "cutlass/cuda_host_adapter.hpp" diff --git a/examples/cute/tutorial/blackwell/01_mma_sm100.cu b/examples/cute/tutorial/blackwell/01_mma_sm100.cu index d2d5a068..8f11a279 100644 --- a/examples/cute/tutorial/blackwell/01_mma_sm100.cu +++ b/examples/cute/tutorial/blackwell/01_mma_sm100.cu @@ -452,8 +452,8 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, dim3 dimBlock(128); dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); - dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), - round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + dim3 dimGrid(size(ceil_div(Gemm_M, bM * size<1>(cluster_layout_vmnk))) * dimCluster.x, + size(ceil_div(Gemm_N, bN * size<2>(cluster_layout_vmnk))) * dimCluster.y); int smemBytes = sizeof(SMEMStorage); auto* kernel_ptr = &gemm_device(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); - dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), - round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + dim3 dimGrid(size(ceil_div(Gemm_M, bM * size<1>(cluster_layout_vmnk))) * dimCluster.x, + size(ceil_div(Gemm_N, bN * size<2>(cluster_layout_vmnk))) * dimCluster.y); int smemBytes = sizeof(SMEMStorage); auto* kernel_ptr = &gemm_device(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); - dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), - round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + dim3 dimGrid(size(ceil_div(Gemm_M, bM * size<1>(cluster_layout_vmnk))) * dimCluster.x, + size(ceil_div(Gemm_N, bN * size<2>(cluster_layout_vmnk))) * dimCluster.y); int smemBytes = sizeof(SMEMStorage); auto* kernel_ptr = &gemm_device(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); dim3 dimGrid(size(ceil_div(Gemm_M, bM * size<1>(cluster_layout_vmnk))) * dimCluster.x, size(ceil_div(Gemm_N, bN * size<2>(cluster_layout_vmnk))) * dimCluster.y); + int smemBytes = sizeof(SMEMStorage); auto* kernel_ptr = &gemm_device(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape)); - dim3 dimGrid(round_up(size(ceil_div(Gemm_M, bM)), dimCluster.x), - round_up(size(ceil_div(Gemm_N, bN)), dimCluster.y)); + dim3 dimGrid(size(ceil_div(Gemm_M, bM * size<1>(cluster_layout_vmnk))) * dimCluster.x, + size(ceil_div(Gemm_N, bN * size<2>(cluster_layout_vmnk))) * dimCluster.y); int smemBytes = sizeof(SMEMStorage); auto* kernel_ptr = &gemm_device= 80:\n", - " plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n", - " plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n", + " plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)\n", + " plan.math_operation = cutlass_cppgen.MathOperation.multiply_add_fast_f32\n", "\n", " # Create input/output tensors in FP32\n", " A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n", @@ -433,9 +433,9 @@ "\n", "# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n", "if device_cc() >= 90:\n", - " plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n", - " layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n", - " layout_C=cutlass.LayoutType.ColumnMajor)\n", + " plan = cutlass_cppgen.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n", + " layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor,\n", + " layout_C=cutlass_cppgen.LayoutType.ColumnMajor)\n", "\n", " # Create input/output tensors in FP8\n", " A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", diff --git a/examples/python/deprecated/01_epilogue.ipynb b/examples/python/deprecated/01_epilogue.ipynb index 97663f50..f5196d44 100644 --- a/examples/python/deprecated/01_epilogue.ipynb +++ b/examples/python/deprecated/01_epilogue.ipynb @@ -68,7 +68,7 @@ "source": [ "import numpy as np\n", "\n", - "import cutlass\n", + "import cutlass_cppgen\n", "\n", "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", "# omit this information.\n", @@ -112,7 +112,7 @@ "metadata": {}, "outputs": [], "source": [ - "plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n", + "plan = cutlass_cppgen.op.Gemm(element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor)\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" ] }, diff --git a/examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb index 86c86fb6..c3069c21 100644 --- a/examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb +++ b/examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb @@ -75,7 +75,7 @@ "\n", "## Declaring a grouped GEMM via the CUTLASS Python interface\n", "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n", - "simply calls `cutlass.op.GroupedGemm`." + "simply calls `cutlass_cppgen.op.GroupedGemm`." ] }, { @@ -85,11 +85,11 @@ "metadata": {}, "outputs": [], "source": [ - "import cutlass\n", + "import cutlass_cppgen\n", "import torch\n", "\n", "dtype = torch.float16\n", - "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)" + "plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor)" ] }, { @@ -174,7 +174,7 @@ "outputs": [], "source": [ "op = plan.construct()\n", - "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)" + "grouped_gemm = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)" ] }, { @@ -182,7 +182,7 @@ "id": "c8ca3991", "metadata": {}, "source": [ - "The `cutlass.emit.pytorch` function emits:\n", + "The `cutlass_cppgen.emit.pytorch` function emits:\n", "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n", "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n", "* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n", diff --git a/examples/python/deprecated/03_basic_conv2d.ipynb b/examples/python/deprecated/03_basic_conv2d.ipynb index 09ebd7bd..aa41997b 100644 --- a/examples/python/deprecated/03_basic_conv2d.ipynb +++ b/examples/python/deprecated/03_basic_conv2d.ipynb @@ -62,7 +62,7 @@ "import torch\n", "import random\n", "\n", - "import cutlass\n", + "import cutlass_cppgen\n", "\n", "# This controls whether the C++ GEMM declaration will be printed at each step. \n", "# Set to `false` to omit this information.\n", @@ -80,7 +80,7 @@ "dilation = (1, 1)\n", "\n", "# Compute the output size [N, P, Q, K]\n", - "N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)\n", + "N, P, Q, K = cutlass_cppgen.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)\n", "\n", "dtype = torch.float16\n", "type_A = torch.float16\n", @@ -111,7 +111,7 @@ "source": [ "## Declaring and running a Conv2d Fprop\n", "\n", - "We first show you how to run a Conv2d in the forward propagation. To get started, one only needs to provide the tensors declared above to the `cutlass.op.Conv2dFprop` call. This sets up a default Conv2d fprop operation for the given device on which you are running. \n", + "We first show you how to run a Conv2d in the forward propagation. To get started, one only needs to provide the tensors declared above to the `cutlass_cppgen.op.Conv2dFprop` call. This sets up a default Conv2d fprop operation for the given device on which you are running. \n", "\n", "Assuming that we are runing on SM80, the default is a Conv2d that leverages FP16 Tensor Core operations.\n", "\n", @@ -125,7 +125,7 @@ "outputs": [], "source": [ "# Specifying `element_accumulator` is not required if it is the same as `element`\n", - "plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)\n", + "plan = cutlass_cppgen.Conv2dFprop(element=dtype, element_accumulator=torch.float32)\n", "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)" ] }, @@ -133,7 +133,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "There are many other ways to construct a plan from `cutlass.op.Conv2dFprop` (e.g., by specifying the types of each operand, by providing representative tensors as input). For more details on these, see the documentation in the `cutlass.op.Conv2dFprop` constructor.\n", + "There are many other ways to construct a plan from `cutlass_cppgen.op.Conv2dFprop` (e.g., by specifying the types of each operand, by providing representative tensors as input). For more details on these, see the documentation in the `cutlass_cppgen.op.Conv2dFprop` constructor.\n", "\n", "We then compare the output to running the Conv2d using PyTorch. PyTorch use NCHW layout by default, so permutations are required." ] @@ -200,7 +200,7 @@ "metadata": {}, "outputs": [], "source": [ - "plan_dgrad = cutlass.Conv2dDgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_dgrad = cutlass_cppgen.Conv2dDgrad(element=dtype, element_accumulator=torch.float32)\n", "plan_dgrad.run(grad_output, weight, tensor_C_dgrad, grad_input, stride, padding, dilation, alpha, beta, print_module=print_module)\n", "\n", "grad_input_torch = alpha * torch.nn.grad.conv2d_input(\n", @@ -225,7 +225,7 @@ "metadata": {}, "outputs": [], "source": [ - "plan_wgrad = cutlass.Conv2dWgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_wgrad = cutlass_cppgen.Conv2dWgrad(element=dtype, element_accumulator=torch.float32)\n", "plan_wgrad.run(grad_output, input, tensor_C_wgrad, grad_weight, stride, padding, dilation, alpha, beta, print_module=print_module)\n", "\n", "grad_weight_torch = alpha * torch.nn.grad.conv2d_weight(\n", diff --git a/examples/python/deprecated/04_epilogue_visitor.ipynb b/examples/python/deprecated/04_epilogue_visitor.ipynb index cf66cd24..6ba68aad 100644 --- a/examples/python/deprecated/04_epilogue_visitor.ipynb +++ b/examples/python/deprecated/04_epilogue_visitor.ipynb @@ -67,17 +67,17 @@ "outputs": [], "source": [ "import torch\n", - "import cutlass\n", - "from cutlass.epilogue import relu\n", - "from cutlass import Tensor as FakeTensor\n", - "from cutlass.utils.profiler import CUDAEventProfiler\n", + "import cutlass_cppgen\n", + "from cutlass_cppgen.epilogue import relu\n", + "from cutlass_cppgen import Tensor as FakeTensor\n", + "from cutlass_cppgen.utils.profiler import CUDAEventProfiler\n", "\n", "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", "# omit this information.\n", "print_module = True\n", "\n", "# The Epilogue Visitor feature currently only works for SM80 and 90\n", - "from cutlass.backend.utils.device import device_cc\n", + "from cutlass_cppgen.backend.utils.device import device_cc\n", "if device_cc() not in [80, 90]:\n", " import sys\n", " sys.exit()\n", @@ -99,7 +99,7 @@ "tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", "tensor_D = torch.zeros_like(tensor_C)\n", "\n", - "plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)" + "plan = cutlass_cppgen.op.Gemm(element=torch.float16, layout=cutlass_cppgen.LayoutType.RowMajor, element_accumulator=torch.float32)" ] }, { @@ -115,7 +115,7 @@ "\n", "The example tensors is a dictionary with tensor names as keys and reference tensors as values. The reference tensors can be `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`. They provides the shape and data type information of the inputs and outputs of the epilogue.\n", "\n", - "The epilogue can be generated simply through `cutlass.evt.trace(, )`." + "The epilogue can be generated simply through `cutlass_cppgen.evt.trace(, )`." ] }, { @@ -139,7 +139,7 @@ "bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", "tensor_F = torch.zeros_like(tensor_D)\n", "examples_tensors = {\n", - " \"accum\": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),\n", + " \"accum\": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass_cppgen.LayoutType.RowMajor),\n", " \"alpha\": alpha,\n", " \"C\": tensor_C,\n", " \"beta\": beta,\n", @@ -150,7 +150,7 @@ "}\n", "\n", "# Trace the epilogue visitor\n", - "epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)" + "epilogue_visitor = cutlass_cppgen.epilogue.trace(example_epilogue, examples_tensors)" ] }, { diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 1da05ad4..66fc49c0 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -254,11 +254,11 @@ copy(AutoVectorizingCopyWithAssumedAlignment const&, if constexpr (common_elem > 1) { constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); - constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); - if constexpr ((vec_bits % 8) == 0) + if constexpr ((vec_bits % 8) == 0 && sizeof_bits_v < Int{}) { - // If more than one element vectorizes to 8bits or more, then recast and copy + // If more than one element vectorizes to a multiple of 8bits that is larger than the value_type, then recast and copy using VecType = uint_bit_t; // Recast diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index 8ec8ffb2..7ce05f30 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -54,13 +54,15 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED # define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED @@ -68,11 +70,12 @@ # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED @@ -83,32 +86,59 @@ # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUTE_ARCH_TMA_SM90_ENABLED # define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED # define CUTE_ARCH_STSM_SM90_ENABLED #endif +// SM110 specific configs +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) # define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TCGEN05_TMEM_ENABLED #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) # define CUTE_ARCH_TMA_SM100_ENABLED #endif @@ -120,12 +150,13 @@ # define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUTE_ARCH_MMA_SM120_ENABLED # define CUTE_ARCH_TMA_SM120_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) # if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) # define CUTE_ARCH_F8F6F4_MMA_ENABLED # define CUTE_ARCH_MXF8F6F4_MMA_ENABLED @@ -134,7 +165,16 @@ # endif #endif -#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED # define CUTE_ARCH_TCGEN05_TMEM_ENABLED @@ -149,14 +189,16 @@ # define CUTE_ARCH_TMA_SM100_ENABLED #endif -#if defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUTE_ARCH_LDSM_SM100A_ENABLED # define CUTE_ARCH_STSM_SM100A_ENABLED #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) # define CUTE_ARCH_LOAD256_SM100A_ENABLED # define CUTE_ARCH_STORE256_SM100A_ENABLED @@ -168,3 +210,7 @@ #define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif +#if defined(CUTLASS_ARCH_MMA_SM103_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED +#endif + diff --git a/include/cute/arch/copy_sm100.hpp b/include/cute/arch/copy_sm100.hpp index aa969afe..f8a9a67b 100644 --- a/include/cute/arch/copy_sm100.hpp +++ b/include/cute/arch/copy_sm100.hpp @@ -41,6 +41,51 @@ namespace cute { //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Global Memory Load and Store PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_LOAD_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint256_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint256_t const& gmem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { + #if defined(CUTE_ARCH_LOAD256_SM100A_ENABLED) + asm volatile("ld.global.L1::no_allocate.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "l"(&gmem_addr) ); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use LOAD.256 without CUTE_ARCH_LOAD256_SM100A_ENABLED."); + #endif + } +}; + +struct SM100_STORE_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint256_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint256_t& gmem_addr) + { + #if defined(CUTE_ARCH_STORE256_SM100A_ENABLED) + asm volatile("st.global.L1::no_allocate.v8.f32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "l"(&gmem_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7)); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use stg.256 without CUTE_ARCH_STORE256_SM100A_ENABLED."); + #endif + } +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// // // LDSM PTX definitions diff --git a/include/cute/arch/copy_sm100_tma.hpp b/include/cute/arch/copy_sm100_tma.hpp index f69cbffd..ef6cfe01 100644 --- a/include/cute/arch/copy_sm100_tma.hpp +++ b/include/cute/arch/copy_sm100_tma.hpp @@ -37,6 +37,8 @@ #include #include +#include "cutlass/arch/synclog.hpp" + namespace cute { diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp index 0c34bc73..a6cb2387 100644 --- a/include/cute/arch/copy_sm75.hpp +++ b/include/cute/arch/copy_sm75.hpp @@ -41,11 +41,13 @@ // * https://reviews.llvm.org/D121666 // * https://reviews.llvm.org/D126846 #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) + #define CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75 (__clang_major__ >= 15) #endif #if defined(__NVCC__) || defined(__CUDACC_RTC__) // ldmatrix PTX instruction added in CUDA 10.2+ #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) + #define CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) #endif #if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) @@ -60,12 +62,19 @@ #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 #endif -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) - #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 -#else - #define CUTE_ARCH_MOVM_SM75_ACTIVATED 0 +#if ! defined(CUTE_ARCH_MOVM_SM75_SUPPORTED) + #define CUTE_ARCH_MOVM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75) #endif +#if ! defined(CUTE_ARCH_MOVM_SM75_ENABLED) + #define CUTE_ARCH_MOVM_SM75_ENABLED (CUTE_ARCH_MOVM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_MOVM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 +#endif + + namespace cute { @@ -207,7 +216,6 @@ struct SM75_U32x1_MOVM_T #endif } }; - // // Legacy LDSM interfaces that aren't very useful // diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 9b9819ce..63bfb563 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -265,7 +265,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { inline CUtensorMapFloatOOBfill to_CUtensorMapFloatOOBfill(OOBFill const& t) { switch(t) { - default: throw std::runtime_error("Unknown OOBFill!"); + default: throw std::runtime_error("Unknown OOBFill!"); case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; } diff --git a/include/cute/arch/mma_sm100_desc.hpp b/include/cute/arch/mma_sm100_desc.hpp index f15108a4..d255c41f 100644 --- a/include/cute/arch/mma_sm100_desc.hpp +++ b/include/cute/arch/mma_sm100_desc.hpp @@ -459,7 +459,7 @@ union InstrDescriptorBlockScaled scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0 m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID - : 1; // + k_size_ : 1; // bit [31,32) : MMA-K Dim. MXF8F6F4Format: 0=[dense: K32, sparse: K64]. S8Format: 0=[dense: K32, sparse: invalid]. MXF4Format: 0=[dense: K64, sparse: K128], 1=[dense: K96, sparse: invalid]. }; // Decay to a uint32_t diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp index f754e266..25b504dc 100644 --- a/include/cute/arch/mma_sm100_umma.hpp +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -46,10 +46,8 @@ template +struct SM103_MXF4_ULTRA_SS_VS +{ + static_assert(M == 128, "MMA M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; + + +template +struct SM103_MXF4_ULTRA_2x1SM_SS_VS +{ + static_assert(M == 128 || M == 256, "MMA M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_2x1SM_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; +} // namespace SM103 + } // end namespace cute diff --git a/include/cute/arch/mma_sm89.hpp b/include/cute/arch/mma_sm89.hpp index e810ce42..b6c44094 100644 --- a/include/cute/arch/mma_sm89.hpp +++ b/include/cute/arch/mma_sm89.hpp @@ -292,4 +292,5 @@ struct SM89_16x8x32_F16E5M2E5M2F16_TN #endif } }; + } // namespace cute diff --git a/include/cute/arch/mma_sm90_gmma_sparse.hpp b/include/cute/arch/mma_sm90_gmma_sparse.hpp index 453dbb7b..cd97a3dd 100644 --- a/include/cute/arch/mma_sm90_gmma_sparse.hpp +++ b/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -34,6 +34,8 @@ #include // CUTE_HOST_DEVICE #include // GMMA::Major, etc. +#include "cutlass/arch/synclog.hpp" + namespace cute { namespace SM90::GMMA::SPARSE { diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index 594149d4..216d739e 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -43,6 +43,36 @@ namespace cute { +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + template <> struct Copy_Traits { diff --git a/include/cute/atom/copy_traits_sm100_tma.hpp b/include/cute/atom/copy_traits_sm100_tma.hpp index 0212db11..8621e8a9 100644 --- a/include/cute/atom/copy_traits_sm100_tma.hpp +++ b/include/cute/atom/copy_traits_sm100_tma.hpp @@ -135,6 +135,13 @@ struct Copy_Traits uint64_t*, // smem mbarrier uint64_t // cache hint > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; ////////////////////////////////////////////////////////////////////////////// @@ -223,6 +230,13 @@ struct Copy_Traits uint16_t, // multicast mask uint64_t // cache hint > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; //////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm75.hpp b/include/cute/atom/copy_traits_sm75.hpp index fdd00081..c43b7483 100644 --- a/include/cute/atom/copy_traits_sm75.hpp +++ b/include/cute/atom/copy_traits_sm75.hpp @@ -156,5 +156,4 @@ struct Copy_Traits // Reference map from (thr,val) to bit using RefLayout = DstLayout; }; - } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 209a8448..3be30af7 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -156,6 +156,13 @@ struct Copy_Traits copy_unpack(Copy_Traits const& traits, Tensor const& src, Tensor & dst) = delete; + + // Construct with updated TMA descriptor only (no barrier change) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {*new_tma_desc, aux_params_}; + } }; // The executable SM90_TMA_LOAD with tma_desc and tma_mbar @@ -181,6 +188,13 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) : opargs_(desc, mbar, cache) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; // The prefetch for SM90_TMA_LOAD with tma_desc @@ -199,10 +213,22 @@ struct Copy_Traits tuple const opargs_; // Construct with any other Traits' TMA Desc - template + template CUTE_HOST_DEVICE - Copy_Traits(Copy_Traits const& traits) - : opargs_({&traits.tma_desc_}) {} + Copy_Traits(OtherTraits const& traits) + : opargs_({traits.get_tma_descriptor()}) {} + + // Construct directly with a TMA descriptor pointer + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc) + : opargs_({desc}) {} + + // Build a new Prefetch traits with a different TMA descriptor pointer + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } template @@ -312,6 +338,13 @@ struct Copy_Traits CUTE_HOST_DEVICE Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) : opargs_(desc, mbar, mask, hint) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } }; ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index 2e69c7bb..dd4d15e7 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -2639,10 +2639,10 @@ struct MMA_Traits <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types"); static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || - (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), - "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ - or a multiple of 16 between 16 and 256 for M=128."); + static_assert(((b_major == UMMA::Major::K) && ((N % 8 == 0) && (8 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 16 == 0) && (16 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 when B is K major. \ + SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 16 between 16 and 256 when B is MN major."); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; using FrgTypeC = UMMA::tmem_frg_1sm; @@ -3051,14 +3051,16 @@ struct MMA_Traits, cute::integral_constant> { - using ValTypeD = c_type; using ValTypeA = a_type; using ValTypeB = b_type; using ValTypeC = c_type; static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); - static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(((b_major == UMMA::Major::K) && ((N % 16 == 0) && (16 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 32 == 0) && (32 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256 when B is K major. \ + SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256 when B is MN major."); using FrgTypeA = UMMA::smem_desc; using FrgTypeB = UMMA::smem_desc; @@ -3879,4 +3881,152 @@ struct MMA_Traits using CLayout = Layout>; }; +namespace SM103 { + // Common mma_unpack for all MMA_Ops in cute::SM103 +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& zA, + Tensor const& zB, + Tensor const& C) + { + auto [A, next_A, SFA] = unzip_tensor(zA); + auto [B, next_B, SFB] = unzip_tensor(zB); + + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_next_a = next_A[0]; + uint64_t desc_b = B[0]; + uint64_t desc_next_b = next_B[0]; + + auto desc_a_temp = reinterpret_cast(desc_a); + auto desc_next_a_temp = reinterpret_cast(desc_next_a); + desc_a_temp.lbo_mode_ = 1; + desc_a_temp.leading_byte_offset_ = desc_next_a_temp.start_address_; + + auto desc_b_temp = reinterpret_cast(desc_b); + auto desc_next_b_temp = reinterpret_cast(desc_next_b); + desc_b_temp.lbo_mode_ = 1; + desc_b_temp.leading_byte_offset_ = desc_next_b_temp.start_address_; + + uint32_t tmem_c = raw_pointer_cast(D.data()); + UMMA::InstrDescriptorBlockScaled instr_desc = traits.idesc_; + instr_desc.k_size_ = 1; + auto tsfa_addr = raw_pointer_cast(SFA.data()); + auto tsfb_addr = raw_pointer_cast(SFB.data()); + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(instr_desc, tsfa_addr, tsfb_addr); + // print("a: "); print(A); print("\n"); + // print("b: "); print(B); print("\n"); + + MMA_Op::fma(reinterpret_cast(desc_a_temp), reinterpret_cast(desc_b_temp), tmem_c, uint32_t(traits.accumulate_), idesc, tsfa_addr, tsfb_addr); + } +} // end namespace SM103 + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + } // end namespace cute diff --git a/include/cute/atom/mma_traits_sm89.hpp b/include/cute/atom/mma_traits_sm89.hpp index d438fd3c..f56aab4c 100644 --- a/include/cute/atom/mma_traits_sm89.hpp +++ b/include/cute/atom/mma_traits_sm89.hpp @@ -67,7 +67,7 @@ struct MMA_Traits { }; template <> -struct MMA_Traits +struct MMA_Traits : MMA_Traits { using ValTypeD = float; using ValTypeA = float_e4m3_t; @@ -129,5 +129,4 @@ struct MMA_Traits using ValTypeC = cutlass::half_t; }; - } // end namespace cute diff --git a/include/cute/atom/partitioner.hpp b/include/cute/atom/partitioner.hpp index 75a55ccf..b15de6e2 100644 --- a/include/cute/atom/partitioner.hpp +++ b/include/cute/atom/partitioner.hpp @@ -31,8 +31,9 @@ #pragma once +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp index a431fc4a..31606e45 100644 --- a/include/cute/container/array.hpp +++ b/include/cute/container/array.hpp @@ -391,9 +391,9 @@ cute::array reverse(cute::array const& t) // // Specialize tuple-related functionality for cute::array // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index 38da7ace..be6410b3 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -555,9 +555,9 @@ void fill(array_subbyte& a, T const& value) // // Specialize tuple-related functionality for cute::array_subbyte // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index dfffbe25..ff9498eb 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -77,9 +77,9 @@ find(type_list const&) noexcept { // // Specialize tuple-related functionality for cute::type_list // - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(tuple) #else #include #endif diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 3844d187..cb161369 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1474,49 +1474,33 @@ domain_distribute(ShapeA const& a, ShapeB const& b) // Kernel (Nullspace) of a Layout // -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -nullspace_seq(Stride const& stride, seq) -{ - if constexpr (NextI == rank_v) { - return seq{}; - } else - if constexpr (is_constant<0, decltype(get(stride))>::value) { - return detail::nullspace_seq(stride, seq{}); - } else { - return detail::nullspace_seq(stride, seq{}); - } - - CUTE_GCC_UNREACHABLE; -} - -} // end namespace detail - -// -// Build the nullspace of a layout -// @result A layout @a result such that -// size(@a result) == size(@a layout) / size(filter(@a layout)) -// @a layout(@a result(i)) == 0 for all i < size(@a result) -// - +/** Return a layout that represents the nullspace of @a layout + * @post @a layout(@a result(i)) == 0 for all i < size(@a result) + * @post nullspace(@a result) == Layout<_1,_0>{} + * @post size(@a result) == size(@a layout) / size(filter(@a layout)) + */ template CUTE_HOST_DEVICE constexpr auto nullspace(Layout const& layout) { - auto flat_layout = flatten(layout); + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); - [[maybe_unused]] auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + // Select all indices corresponding to stride-0s + auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } + CUTE_GCC_UNREACHABLE; + }); - if constexpr (iseq.size() == 0) { + if constexpr (tuple_size::value == 0) { return Layout<_1,_0>{}; // Empty case, nothing found } else { // Generate the corresponding new strides and construct - auto rstride = compact_major(flat_layout.shape()); - return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + auto flat_shape = flatten(layout.shape()); + auto rstride = compact_major(flat_shape); + return make_layout(unwrap(transform(iseq, [&](auto i) { return get(flat_shape); })), unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); } diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 9c811340..016ac5b6 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -469,7 +469,9 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) template CUTE_HOST_DEVICE void print(ScaledBasis const& e) { - print(e.value()); (void(printf("@%d", Ns)), ...); + print(e.value()); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(printf("@%d", Ns)), 0)); } #if !defined(__CUDACC_RTC__) @@ -482,7 +484,9 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { - os << e.value(); (void(os << "@" << Ns), ...); + os << e.value(); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(os << "@" << Ns),0)); return os; } #endif diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 71464e72..cfef8cc5 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -29,9 +29,9 @@ * **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif @@ -85,6 +85,8 @@ using CUTE_STL_NAMESPACE::uint16_t; using CUTE_STL_NAMESPACE::uint32_t; using CUTE_STL_NAMESPACE::uint64_t; using cutlass::uint128_t; +using cutlass::uint256_t; + template struct uint_bit; template <> struct uint_bit< 1> { using type = uint1_t; }; template <> struct uint_bit< 2> { using type = uint2_t; }; @@ -95,6 +97,8 @@ template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; template <> struct uint_bit< 64> { using type = uint64_t; }; template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; +template <> struct uint_bit<256> { using type = cutlass::uint256_t; }; + template using uint_bit_t = typename uint_bit::type; diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 0104c31f..952d156e 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -225,6 +225,27 @@ operator==(C, R) { return {}; } +template +CUTE_HOST_DEVICE constexpr +bool_constant::num * R::den < R::num * R::den> +operator<(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num < c * R::den> +operator<(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::den < R::num> +operator<(C, R) { + return {}; +} + /////////////////////// // Special functions // /////////////////////// diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index ef1ca18e..d49d7100 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -165,6 +165,15 @@ get_nonswizzle_portion(Layout const& slayout) return slayout; } +// Return the codomain size of a Swizzled ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout,Offset,LayoutB> const& layout) +{ + return cosize(layout.layout_b()); +} + // // Slice a Swizzled ComposedLayout // diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index 007d8c03..3d24cf45 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -761,12 +761,12 @@ recast(Tensor&& tensor) using OldType = typename remove_cvref_t::element_type; using NewType = copy_cv_t; - auto old_layout = tensor.layout(); - auto new_layout = recast_layout(old_layout); - if constexpr (is_same::value) { - return tensor; + return make_tensor(static_cast(tensor).data(), tensor.layout()); } else { + auto old_layout = tensor.layout(); + auto new_layout = recast_layout(old_layout); + // If this is an upcast of a normal Layout with static negative strides, then offset as well if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index 34cc5ca9..685babff 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -29,13 +29,13 @@ * **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include -#include -#include -#include -#include +#include CUDA_STD_HEADER(type_traits) +#include CUDA_STD_HEADER(utility) +#include CUDA_STD_HEADER(cstddef) +#include CUDA_STD_HEADER(cstdint) +#include CUDA_STD_HEADER(limits) #else #include #include // tuple_size, tuple_element diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index c9c636a0..5c3bd8a7 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -109,6 +109,10 @@ struct Sm120 { static int const kMinComputeCapability = 120; }; +struct Sm103 { + static int const kMinComputeCapability = 103; +}; + /// Triggers a breakpoint on the device CUTLASS_DEVICE void device_breakpoint() { diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 3d5ec10b..2b0d4bb6 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -46,11 +46,13 @@ #endif -#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif -#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED)) +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) #define CUTLASS_ARCH_TCGEN_ENABLED 1 #endif @@ -386,6 +388,7 @@ public: // CUTLASS_HOST_DEVICE static void init(ValueType const* smem_ptr, uint32_t arrive_count) { + CUTLASS_ASSERT(arrive_count != 0 && "Arrive count must be non-zero"); #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index 60be8d72..873e6437 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -128,6 +128,27 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +// SM110 and SM110a only on 13.0 and above +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM110_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM110_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1100) + #define CUTLASS_ARCH_MMA_SM110_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM110_ALL)) + #define CUTLASS_ARCH_MMA_SM110A_ENABLED 1 + #endif + + // SM110f + #if (__CUDACC_VER_MAJOR__ > 13 || (__CUDACC_VER_MAJOR__ == 13 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM110F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) && CUDA_ARCH_FAMILY(1100)) + #define CUTLASS_ARCH_MMA_SM110F_ENABLED CUTLASS_ARCH_MMA_SM110F_SUPPORTED + #endif + #endif +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// // SM120 and SM120a @@ -151,10 +172,56 @@ #endif #endif +// SM103 and SM103a +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM103_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM103_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1030) + #define CUTLASS_ARCH_MMA_SM103_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM103_ALL)) + #define CUTLASS_ARCH_MMA_SM103A_ENABLED 1 + #endif + + // SM103f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM103F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) && CUDA_ARCH_FAMILY(1030)) + #define CUTLASS_ARCH_MMA_SM103F_ENABLED CUTLASS_ARCH_MMA_SM103F_SUPPORTED + #endif + #endif +#endif + +// SM121 and SM121a +#if !CUTLASS_CLANG_CUDA && (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM121_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM121_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 1210) + #define CUTLASS_ARCH_MMA_SM121_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) &&\ + (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL(1210))) + #define CUTLASS_ARCH_MMA_SM121A_ENABLED 1 + #endif + + // SM121f + #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) + #define CUTLASS_ARCH_MMA_SM121F_SUPPORTED 1 + #endif + + #if (!defined(CUTLASS_ARCH_MMA_SM121F_ENABLED) && CUDA_ARCH_FAMILY(1210)) + #define CUTLASS_ARCH_MMA_SM121F_ENABLED CUTLASS_ARCH_MMA_SM121F_SUPPORTED + #endif + #endif +#endif + #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUTLASS_ARCH_CLC_ENABLED #endif diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h index f1e0200f..e5e99c85 100644 --- a/include/cutlass/arch/grid_dependency_control.h +++ b/include/cutlass/arch/grid_dependency_control.h @@ -62,8 +62,12 @@ (defined(__CUDA_ARCH_FEAT_SM100_ALL) || CUDA_ARCH_FAMILY(1000))) || \ (__CUDA_ARCH__ == 1010 &&\ (defined(__CUDA_ARCH_FEAT_SM101_ALL) || CUDA_ARCH_FAMILY(1010))) || \ + (__CUDA_ARCH__ == 1030 &&\ + (defined(__CUDA_ARCH_FEAT_SM103_ALL) || CUDA_ARCH_FAMILY(1030))) || \ (__CUDA_ARCH__ == 1200 &&\ - (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))))) + (defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_FAMILY(1200))) || \ + (__CUDA_ARCH__ == 1210 &&\ + (defined(__CUDA_ARCH_FEAT_SM121_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210))))) #define CUTLASS_GDC_ENABLED #endif #endif diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 4e812935..b91a198b 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -40,6 +40,7 @@ #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/synclog.hpp" #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #define CUDA_CP_ASYNC_ACTIVATED 1 diff --git a/include/cutlass/arch/mma_sm100.h b/include/cutlass/arch/mma_sm100.h index 46fb31f6..2863f2d2 100644 --- a/include/cutlass/arch/mma_sm100.h +++ b/include/cutlass/arch/mma_sm100.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/arch/mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index e4889a21..6acdcfac 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -32,8 +32,8 @@ \brief Matrix multiply */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 120b116b..c71ea076 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/arch/wmma.h" diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index d89974fc..22cd87d6 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -33,10 +33,9 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/arch/mma_sm89.h b/include/cutlass/arch/mma_sm89.h index a4a8b1cb..4bcd9bc1 100644 --- a/include/cutlass/arch/mma_sm89.h +++ b/include/cutlass/arch/mma_sm89.h @@ -34,10 +34,9 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + #include "mma.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h index b1314a56..b135c864 100644 --- a/include/cutlass/arch/mma_sm90.h +++ b/include/cutlass/arch/mma_sm90.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h index 187ccc17..e4ca91a1 100644 --- a/include/cutlass/arch/mma_sparse_sm80.h +++ b/include/cutlass/arch/mma_sparse_sm80.h @@ -34,8 +34,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sparse_sm89.h b/include/cutlass/arch/mma_sparse_sm89.h index 27c40dc4..6adca255 100644 --- a/include/cutlass/arch/mma_sparse_sm89.h +++ b/include/cutlass/arch/mma_sparse_sm89.h @@ -34,8 +34,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index a65ee328..93dd37d3 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -36,13 +36,20 @@ #pragma once #include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif #ifndef CUDA_CTA_RECONFIG_ACTIVATED #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \ || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \ || (__CUDA_ARCH__ == 1010 && defined(__CUDA_ARCH_FEAT_SM101_ALL)) \ + || (__CUDA_ARCH__ == 1030 && defined(__CUDA_ARCH_FEAT_SM103_ALL)) \ || (__CUDA_ARCH__ == 1200 && defined(__CUDA_ARCH_FEAT_SM120_ALL)) \ + || (__CUDA_ARCH__ == 1210 && defined(__CUDA_ARCH_FEAT_SM121_ALL)) \ ) #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif @@ -50,7 +57,9 @@ #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ (__CUDA_ARCH__ == 1000 && CUDA_ARCH_FAMILY(1000)) \ || (__CUDA_ARCH__ == 1010 && CUDA_ARCH_FAMILY(1010)) \ + || (__CUDA_ARCH__ == 1030 && CUDA_ARCH_FAMILY(1030)) \ || (__CUDA_ARCH__ == 1200 && CUDA_ARCH_FAMILY(1200)) \ + || (__CUDA_ARCH__ == 1210 && CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) \ ) #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp index b9819838..5567fe56 100644 --- a/include/cutlass/arch/synclog.hpp +++ b/include/cutlass/arch/synclog.hpp @@ -35,9 +35,9 @@ #pragma once #include "cutlass/detail/helper_macros.hpp" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif @@ -120,44 +120,34 @@ constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + constexpr bool synclog_enable_cluster_barrier_wait = true; constexpr uint32_t synclog_header_cluster_barrier_wait = 6; -constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_barrier_test_wait = true; constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7; -constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5; - +constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 3; constexpr bool synclog_enable_cluster_barrier_try_wait = true; constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; -constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; -constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5; - +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 3; constexpr bool synclog_enable_cluster_barrier_arrive = true; constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; -constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 1; constexpr bool synclog_enable_cluster_barrier_invalidate = true; constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; -constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 1; constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; -constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 6; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 4; constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; -constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 4; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 2; constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; -constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 6; - +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 4; constexpr bool synclog_enable_fence_barrier_init = true; constexpr uint32_t synclog_header_fence_barrier_init = 16; constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; @@ -228,12 +218,11 @@ constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; constexpr bool synclog_enable_cpasync_barrier_arrive = true; constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; -constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; - +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 1; CUTLASS_DEVICE bool synclog_condition_emit() { #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) - return threadIdx.x%NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; #else return 0; @@ -272,17 +261,6 @@ void synclog_print_prefix(char const* header, uint32_t at) { #endif } -CUTLASS_DEVICE -uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { - uint64_t bits = 0; - asm volatile ( - "mbarrier.inval.shared::cta.b64 [%1];\n" - "ld.shared::cta.b64 %0, [%1];\n" - : "=l"(bits) : "r"(smem_addr) - ); - return bits; -} - CUTLASS_DEVICE void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { CUTLASS_UNUSED(hi); @@ -429,14 +407,11 @@ void synclog_emit_cluster_barrier_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -453,15 +428,12 @@ void synclog_emit_cluster_barrier_test_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_test_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; to[synclog_length_prefix + 2] = pred; - to[synclog_length_prefix + 3] = bits; - to[synclog_length_prefix + 4] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -478,14 +450,11 @@ void synclog_emit_cluster_barrier_try_wait( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_try_wait) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = phase; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -502,15 +471,12 @@ void synclog_emit_cluster_barrier_arrive_cluster( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = cta_id; to[synclog_length_prefix + 2] = pred; - to[synclog_length_prefix + 3] = bits; - to[synclog_length_prefix + 4] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -526,13 +492,10 @@ void synclog_emit_cluster_barrier_arrive( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_arrive) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -546,13 +509,10 @@ void synclog_emit_cluster_barrier_invalidate( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_barrier_invalidate) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -567,14 +527,11 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = transaction_bytes; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 3] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -592,7 +549,6 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); @@ -600,8 +556,6 @@ void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( to[synclog_length_prefix + 1] = transaction_bytes; to[synclog_length_prefix + 2] = cta_id; to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = bits; - to[synclog_length_prefix + 5] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -619,14 +573,11 @@ void synclog_emit_cluster_transaction_barrier_expect_transaction( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); to[synclog_length_prefix + 0] = smem_addr; to[synclog_length_prefix + 1] = transaction_bytes; - to[synclog_length_prefix + 2] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -644,7 +595,6 @@ void synclog_emit_cluster_transaction_barrier_complete_transaction( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); @@ -652,8 +602,6 @@ void synclog_emit_cluster_transaction_barrier_complete_transaction( to[synclog_length_prefix + 1] = dst_cta_id; to[synclog_length_prefix + 2] = transaction_bytes; to[synclog_length_prefix + 3] = pred; - to[synclog_length_prefix + 4] = bits; - to[synclog_length_prefix + 5] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -977,13 +925,10 @@ void synclog_emit_cpasync_barrier_arrive( #if defined(CUTLASS_ENABLE_SYNCLOG) if constexpr (!synclog_enable_cpasync_barrier_arrive) return; if (!synclog_condition_emit()) return; - uint64_t bits = synclog_mbarrier_bits(smem_addr); uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); if (to == nullptr) return; synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); to[synclog_length_prefix + 0] = smem_addr; - to[synclog_length_prefix + 1] = bits; - to[synclog_length_prefix + 2] = bits >> 32; #else CUTLASS_UNUSED(line); CUTLASS_UNUSED(smem_addr); @@ -1054,7 +999,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_wait) { synclog_print_prefix("cluster_barrier_wait", at); at += synclog_length_cluster_barrier_wait; - printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1062,7 +1007,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_test_wait) { synclog_print_prefix("cluster_barrier_test_wait", at); at += synclog_length_cluster_barrier_test_wait; - printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1070,7 +1015,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_try_wait) { synclog_print_prefix("cluster_barrier_try_wait", at); at += synclog_length_cluster_barrier_try_wait; - printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u phase=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1078,7 +1023,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_arrive_cluster) { synclog_print_prefix("cluster_barrier_arrive_cluster", at); at += synclog_length_cluster_barrier_arrive_cluster; - printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u cta_id=%u pred=%u\n", synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1086,7 +1031,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_arrive) { synclog_print_prefix("cluster_barrier_arrive", at); at += synclog_length_cluster_barrier_arrive; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1094,7 +1039,7 @@ void synclog_print() { if (header == synclog_header_cluster_barrier_invalidate) { synclog_print_prefix("cluster_barrier_invalidate", at); at += synclog_length_cluster_barrier_invalidate; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1102,7 +1047,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; - printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1110,7 +1055,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; - printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1118,7 +1063,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); at += synclog_length_cluster_transaction_barrier_expect_transaction; - printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u transaction_bytes=%u\n", synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1126,7 +1071,7 @@ void synclog_print() { if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); at += synclog_length_cluster_transaction_barrier_complete_transaction; - printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u\n", synclog_buf[at-4], synclog_buf[at-3], synclog_buf[at-2], synclog_buf[at-1]); continue; } } @@ -1283,7 +1228,7 @@ void synclog_print() { if (header == synclog_header_cpasync_barrier_arrive) { synclog_print_prefix("cpasync_barrier_arrive", at); at += synclog_length_cpasync_barrier_arrive; - printf("smem_addr=%u", synclog_buf[at-3]); + printf("smem_addr=%u\n", synclog_buf[at-1]); continue; } } @@ -1302,6 +1247,7 @@ void synclog_print() { //////////////////////////////////////////////////////////////////////////////////////////////////// + #if defined(CUTLASS_ENABLE_SYNCLOG) #undef __syncthreads #define __syncthreads() do {\ @@ -1318,6 +1264,7 @@ void synclog_print() { } while (0) #endif // defined(CUTLASS_ENABLE_SYNCLOG) + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace arch diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index bec61172..2c540be8 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index 95c639e6..1eb553e8 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index 10e9d916..c3535ef0 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -33,8 +33,8 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 3a777113..22c17dba 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -39,9 +39,10 @@ #include "cutlass/cutlass.h" #include "cutlass/trace.h" #include +#include "cutlass/arch/synclog.hpp" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #include diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 723f1e3f..0287850b 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -34,14 +34,12 @@ #include #include - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif - -#include "cutlass/cutlass.h" #include "cutlass/functional.h" #include "cutlass/platform/platform.h" #include "cutlass/real.h" diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index 0874d8f8..327fc27d 100644 --- a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -279,10 +279,19 @@ public: implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); } - if constexpr (is_grouped_wgrad) { - auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); - auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, args.hw_info.cluster_shape_fallback); + // implicit gemm B tile can be small for conv, ensure multicast smem offsets are 128B aligned + int multicast_b_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape)) * sizeof_bits_v; + int multicast_b_fallback_bits = (size<1>(TileShape{}) * size<2>(TileShape{}) / size<0>(cluster_shape_fallback)) * sizeof_bits_v; + implementable &= multicast_b_bits % (128*8) == 0 && multicast_b_fallback_bits % (128*8) == 0; + if (not implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: multicast size too large for B tile\n"); + return false; + } + + if constexpr (is_grouped_wgrad) { implementable &= size<0>(cluster_shape) == 1 && size<0>(cluster_shape_fallback) == 1; if (!implementable) { diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 9363d1e9..16cfa1b3 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -33,15 +33,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include "cutlass/cutlass.h" - namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index ed81aec9..c68a3ba3 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -35,9 +35,14 @@ #pragma once -#include "cutlass/arch/synclog.hpp" #include "cutlass/detail/helper_macros.hpp" +#if (__CUDACC_VER_MAJOR__ >= 13) + #define CUDA_STD_HEADER(header) +#else + #define CUDA_STD_HEADER(header) +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/include/cutlass/detail/blockwise_scale_layout.hpp b/include/cutlass/detail/blockwise_scale_layout.hpp index c05498c5..a304cd6e 100644 --- a/include/cutlass/detail/blockwise_scale_layout.hpp +++ b/include/cutlass/detail/blockwise_scale_layout.hpp @@ -32,7 +32,7 @@ /*! \file - \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA + \brief Blockwise Scale configs specific for Blockwise/Groupwise MMA */ #pragma once @@ -41,6 +41,7 @@ #include "cute/int_tuple.hpp" #include "cute/atom/mma_traits_sm100.hpp" +#include "cute/arch/mma_sm90.hpp" namespace cutlass::detail{ @@ -270,8 +271,13 @@ struct RuntimeBlockwiseScaleConfig { }; // Sm90 only supports MN major for SFA and SFB for now -template -using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; +template +using Sm90BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig< + SFVecSizeM, + SFVecSizeN, + SFVecSizeK, + majorSFA == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K, + majorSFB == cute::GMMA::Major::MN ? UMMA::Major::MN : UMMA::Major::K>; template using Sm100BlockwiseScaleConfig = Sm1xxBlockwiseScaleConfig; diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index 3847c712..b6a19c94 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -483,6 +483,7 @@ void LayoutAwareConvert( decltype(dst_layout)>::convert(src_vm, dst_vm); } + } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -618,6 +619,32 @@ public: } } + static constexpr uint32_t + compute_tma_transaction_bytes_extra_transform() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(filter_zeros(SmemLayoutScale{})) * size<1>(filter_zeros(SmemLayoutScale{})) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + /// Utilities to copy A and extra inputs from smem to RF template + CUTLASS_DEVICE + static void copy_scale_zeros_for_transform( + cute::tuple & partitioned_transform_extra_info, + int load2transform_consumer_index) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(partitioned_transform_extra_info); + auto&& scales = cute::get<1>(partitioned_transform_extra_info); + using ScaleType = decltype(scales); + auto tSrS = make_tensor(static_cast(scales).data(), scales.layout()); + auto tSsS = cute::get<2>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tSsS(_,_,_,_,load2transform_consumer_index), tSrS); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto&& zeros = cute::get<3>(partitioned_transform_extra_info); + using ZeroType = decltype(zeros); + auto tZrZ = make_tensor(static_cast(zeros).data(), zeros.layout()); + auto tZsZ = cute::get<4>(partitioned_transform_extra_info); + copy(smem_tiled_copy_S, tZsZ(_,_,_,_,load2transform_consumer_index), tZrZ); + + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + // Helper functions to select packing for conversion template + CUTLASS_DEVICE + static void dequantize_A_kblock_for_transform( + Tensor const& tArA, + Tensor& tArACompute, + cute::tuple const& partitioned_extra_info, + int const k_block) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + auto src = tArA(_, _, _, k_block); + auto dst = tArACompute(_, _, _, k_block); + constexpr int num_elements = decltype(size(src))::value; + + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int DstElementsPerReg = 32 / sizeof_bits_v; + using RegArray = cutlass::AlignedArray; + + auto src_arr = recast(src); + auto dst_arr = recast(dst); + + Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, pack)); + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); + + cute::transform(src_arr, dst_arr, Converter::convert); + + if constexpr (ModeHasScales) { + + auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block); + + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + + if constexpr (is_same_v) { + + using ScaleArray = cutlass::Array; + auto scale_arr = recast(filter_zeros(scales)); + + if constexpr (is_same_v){ + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack)); + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& scale_reg = cute::recast(scales_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hmul2(bf16x2_val, + reinterpret_cast(scale_reg[ii])); + } + } + } + else{ + cute::transform(dst_arr, scale_arr, dst_arr, cute::multiplies{}); + } + } + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Do Nothing + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + + auto const& zeros = cute::get<3>(partitioned_extra_info)(_,_,_,k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + + if constexpr (is_same_v) { + using ZeroArray = cutlass::Array; + auto zero_arr = recast(filter_zeros(zeros)); + + if constexpr (is_same_v) { + Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, pack)); + + for (int i = 0; i < size<1>(dst_vm); ++i){ + auto&& r = cute::recast(dst_vm(_,i))(0); + auto&& zero_reg = cute::recast(zeros_vm(_,i))(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hadd2(bf16x2_val, + reinterpret_cast(zero_reg[ii])); + } + } + } + else{ + cute::transform(dst_arr, zero_arr, dst_arr, cute::plus{}); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } +} + /// Utilities for any additional inputs inside of the TMA load template < @@ -998,6 +1171,65 @@ public: } } + template < + class TiledMma, + class TiledCopy, + class TensorStorage + > + CUTLASS_DEVICE + static auto partition_extra_transform_info( + TiledMma const& tiled_mma, + TiledCopy const& smem_tiled_copy_S, + TensorStorage& shared_storage) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + auto smem_thr_copy_S = smem_tiled_copy_S.get_slice(threadIdx.x % 128); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = cta_mma.partition_A(sS); + Tensor tSsS = smem_thr_copy_S.partition_S(tCsS); + Tensor tSrS = make_tensor(tSsS(_,_,_,_,0).shape()); +#if 0 + if(cute::thread(128, 0)){ + print("sS: ");print(sS);print("\n"); + print("tSsS: ");print(tSsS);print("\n"); + print("tSrS: ");print(tSrS);print("\n"); + } +#endif + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = cta_mma.partition_A(sZ); + Tensor tZsZ = smem_thr_copy_S.partition_S(tCsZ); + Tensor tZrZ = make_tensor(tZsZ(_,_,_,_,0).shape()); +#if 0 + if(cute::thread(128, 0)){ + print("sS: ");print(sS);print("\n"); + print("tSsS: ");print(tSsS);print("\n"); + print("tSrS: ");print(tSrS);print("\n"); + print("sZ: ");print(sZ);print("\n"); + print("tZsZ: ");print(tZsZ);print("\n"); + print("tZrZ: ");print(tZrZ);print("\n"); + } +#endif + return cute::make_tuple(smem_tiled_copy_S, tSrS, tSsS, tZrZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + /// Returns the tiled copy and copy views for the extra inputs. template CUTLASS_DEVICE diff --git a/include/cutlass/detail/collective/sm103_kernel_type.hpp b/include/cutlass/detail/collective/sm103_kernel_type.hpp new file mode 100644 index 00000000..04120a41 --- /dev/null +++ b/include/cutlass/detail/collective/sm103_kernel_type.hpp @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Kernel type definitions specific for SM103 BlockScaled MMA +*/ + +#pragma once + +namespace cutlass::sm103::detail { + +enum class KernelPrefetchType { + TmaPrefetch, // TMA Prefetch (is the default version) + Disable // Disable Prefetch +}; + +} // namespace cutlass::sm103::detail diff --git a/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp b/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp new file mode 100644 index 00000000..b6c92c4d --- /dev/null +++ b/include/cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm100MixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, int32_t>, Shape, int32_t>, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layout_scale(CtaShape_MN_K cta_shape_mn_k) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + + int constexpr size_MN = cute::get<0>(CtaShape_MN_K{}); + int constexpr size_K = cute::get<1>(CtaShape_MN_K{}); + + int constexpr SmemSizeMN = (SFVecSizeMN < size_MN) + ? SFVecSizeMN + : size_MN; + + int constexpr SmemSizeK = (SFVecSizeK < size_K) + ? SFVecSizeK + : size_K; + + int constexpr div_MN = cute::ceil_div(size_MN, SmemSizeMN); + int constexpr div_K = cute::ceil_div(size_K, SmemSizeK); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int{})); + } + else { + return make_stride(make_stride(_0{}, Int{}), make_stride(_0{}, _1{})); + } + }(); + + return make_layout( + make_shape(make_shape(Int{}, Int{}), + make_shape(Int{}, Int{})), + strides + ); + } + + + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ScaledInputDim scale_input_dims) { + const auto scale_input_dims_MNKL = append<3>(scale_input_dims, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = scale_input_dims_MNKL; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, SFVecSizeMN))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = scale_input_dims_MNKL; + auto mk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(MN, SFVecSizeMN)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +template +struct RuntimeMixedInputBlockwiseScaleConfig { + + using ShapeScale = Shape, Shape, int32_t>; + + using StrideScale = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutScale = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layout_scale() { + return LayoutScale{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_S. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_scale(ProblemShape problem_shape, SFVecShape sf_vec_shape) { + auto problem_shape_MNKL = append<3>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + if constexpr (majorScale == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(MN, sfmn))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, sfk)), make_stride(_0{}, _1{})); + } + }(); + + auto [MN, K, L] = problem_shape_MNKL; + auto [sfmn, sfk] = sf_vec_shape; + auto mk_layout = make_layout( + make_shape(make_shape(sfmn, cute::ceil_div(MN, sfmn)), + make_shape(sfk, cute::ceil_div(K, sfk))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/detail/sm103_blockscaled_layout.hpp b/include/cutlass/detail/sm103_blockscaled_layout.hpp new file mode 100644 index 00000000..300448d7 --- /dev/null +++ b/include/cutlass/detail/sm103_blockscaled_layout.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Blocked Scale configs specific for SM103 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm103BlockScaledBasicChunk { + + using Blk_MN = _128; + using Blk_SF = _4; + + using SfKMajorAtom = Layout< Shape< Shape< _8, _4, _4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; + using SfMNMajorAtom = Layout< Shape< Shape, _4>, Shape<_8, _4, _4>>, + Stride, Stride<_16,_128, _4>>>; + using SfAtom = cute::conditional_t; +}; + +template +struct Sm103BlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + using Sm103BlkScaledChunk = Sm103BlockScaledBasicChunk; + using Blk_MN = typename Sm103BlkScaledChunk::Blk_MN; + using Blk_SF = typename Sm103BlkScaledChunk::Blk_SF; + using SfAtom = typename Sm103BlkScaledChunk::SfAtom; + + using LayoutSF = decltype(tile_to_shape(SfAtom{}, make_shape(int(0),int(0),int(0)),Step<_2,_1,_3>{})); + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSF{}; + } + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSF{}; + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template < class ProblemShape, class LayoutSFA = LayoutSF> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape, LayoutSFA layout_sfa = LayoutSFA{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(M,K,L), Step<_2,_1,_3>{}); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape, LayoutSFB layout_sfb = LayoutSFB{}) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + return tile_to_shape(SfAtom{}, make_shape(N,K,L), Step<_2,_1,_3>{}); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 40e19a37..5b1d3e5b 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -35,6 +35,7 @@ #pragma once #include // CUTLASS_HOST_DEVICE +#include // cutlass::arch::synclog_* #include // uint64_t // __grid_constant__ was introduced in CUDA 11.7. diff --git a/include/cutlass/epilogue/collective/builders/sm100_builder.inl b/include/cutlass/epilogue/collective/builders/sm100_builder.inl index b8b752fe..4c3fd96b 100644 --- a/include/cutlass/epilogue/collective/builders/sm100_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm100_builder.inl @@ -50,9 +50,10 @@ #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/operations.hpp" // detail::is_sfd_epilogue_v #include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif @@ -1272,6 +1273,13 @@ private: static constexpr bool Is1SmMma = is_base_of_v; static constexpr bool Is2SmMma = is_base_of_v; + static constexpr bool IsInterleavedComplex = is_complex::value; + static constexpr bool IsFastF32Schedule = is_same_v || + is_same_v || + is_same_v || + is_same_v; + // Input transform kernels - when dispatching to sm100 nosmem epilogue, go through the default path without EVT support. + static constexpr bool IsInputTransformSchedule = IsInterleavedComplex || IsFastF32Schedule; static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); @@ -1315,7 +1323,9 @@ private: static_assert(is_tuple_v, "Shape or Tile"); return EpilogueTileType{}; } - else if constexpr (is_same_v) { // perf specialized case + else if constexpr (is_same_v || not IsInputTransformSchedule) { + // Save register usage for sm103 blockscaled kernels and sm100 cpasync kernels + // to avoid register spilling. constexpr int EpiM = size<0>(CtaTileShape_MNK{}); constexpr int EpiN = cute::min(_64{}, size<1>(CtaTileShape_MNK{})); return Shape, Int>{}; @@ -1333,8 +1343,8 @@ private: static constexpr auto dispatch_policy() { - if constexpr (is_same_v || - is_same_v) { + if constexpr (std::is_base_of_v || + std::is_base_of_v) { return Sm100PtrArrayNoSmemWarpSpecialized{}; } else { @@ -1347,7 +1357,12 @@ private: fusion_callbacks() { constexpr thread::ScaleType::Kind ScaleType = DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; - if constexpr (IsDefaultFusionOp::value && not is_same_v) { + if constexpr (IsDefaultFusionOp::value &&\ + not is_same_v && \ + (IsInputTransformSchedule || \ + is_same_v || \ + is_same_v) + ) { // Legacy codepath using thread::LinearCombination, do not expect this to be stable return thread::LinearCombination< ElementD, 1, ElementAccumulator, ElementCompute, ScaleType, FusionOp::RoundStyle, ElementC>({}); diff --git a/include/cutlass/epilogue/collective/builders/sm103_builder.inl b/include/cutlass/epilogue/collective/builders/sm103_builder.inl new file mode 100644 index 00000000..bd9e5415 --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm103_builder.inl @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/layout.hpp" // cute::Shape +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cutlass/arch/mma.h" // cutlass::arch::OpClassTensorOp, cutlass::OpClassSparseTensorOp +#include "cute/atom/copy_traits_sm100.hpp" +#include "cute/atom/mma_traits_sm100.hpp" +#include "cute/util/type_traits.hpp" // cute::is_same_v + +#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::collective { +// Alias to sm100 builder +template < + class OpClass, + class MmaTileShape_MNK, // Static MMA tile shape + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOp +> +struct CollectiveBuilder< + arch::Sm103, + OpClass, + MmaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp +> +{ + using CollectiveOp = typename CollectiveBuilder< + arch::Sm100, + OpClass, + MmaTileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleType, + FusionOp + >::CollectiveOp; +}; + +} // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/builders/sm120_builder.inl b/include/cutlass/epilogue/collective/builders/sm120_builder.inl index 80e84e9a..6d7e1698 100644 --- a/include/cutlass/epilogue/collective/builders/sm120_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm120_builder.inl @@ -36,10 +36,10 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/collective/builders/sm90_common.inl" #include "cutlass/epilogue/collective/builders/sm120_common.inl" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index a94d9457..731e9f54 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -45,9 +45,9 @@ #include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(type_traits) #else #include #endif diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index bb55c96d..2bd817a5 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -120,6 +120,7 @@ struct CallbacksBuilder< #include "builders/sm90_builder.inl" #include "builders/sm100_builder.inl" +#include "builders/sm103_builder.inl" #include "builders/sm120_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp index 80eea5e2..d3b2d088 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -60,7 +60,9 @@ template < class ElementD_, class StrideD_, class ThreadEpilogueOp_, - class CopyOpT2R_ + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ > class CollectiveEpilogue< Sm100PtrArrayNoSmem, @@ -70,7 +72,10 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_ + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> > { public: // @@ -92,6 +97,10 @@ public: using StrideD = StrideD_; using InternalStrideD = cute::remove_pointer_t; using CopyOpT2R = CopyOpT2R_; + using AlignmentC = AlignmentC_; + using AlignmentD = AlignmentD_; + + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -136,7 +145,7 @@ public: template static size_t - get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) { + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { return 0; } @@ -232,24 +241,56 @@ public: } // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, stride_d); // (M,N,L) Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) auto tiled_t2r = make_tmem_copy(CopyOpT2R{}, tensor<0>(accumulators)); - auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) Tensor tTR_rAcc = make_tensor(shape(tTR_gD)); // (T2R,T2R_M,T2R_N) - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + Tensor tTR_rC = make_tensor(shape(tTR_gC)); // (T2R,T2R_M,T2R_N) + + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + constexpr auto mclD = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gD.layout())){}; + constexpr int VD = cute::min(AlignmentD{}, size(mclD)); + Tensor tTR_rD_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rD_src = recast>(coalesce(tTR_rD_frag)); + Tensor tR2G_rD_dst = recast>(coalesce(tTR_gD)); + + Tensor tTR_cD_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclD.compose(Int{}))); + Tensor tDpD = make_tensor(shape(tR2G_rD_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tDpD); t++) { + tDpD(t) = elem_less(tTR_cD_mn_frg(t), problem_shape_mnl); + } + + constexpr auto mclC = decltype(max_common_layout(tTR_rAcc.layout(), tTR_gC.layout())){}; + constexpr int VC = cute::min(AlignmentC{}, size(mclC)); + + Tensor tTR_cC_mn_frg = tensor<1>(zipped_divide(coalesce(tTR_cCD), mclC.compose(Int{}))); + Tensor tG2R_rC_dst = recast>(coalesce(tTR_gC)); + Tensor tCpC = make_tensor(shape(tG2R_rC_dst)); + + CUTLASS_PRAGMA_UNROLL + for (int t = 0; t < size(tCpC); t++) { + tCpC(t) = elem_less(tTR_cC_mn_frg(t), problem_shape_mnl); + } + Tensor tTR_rC_src = recast>(coalesce(tTR_gC)); + Tensor tTR_rC_dst = recast>(coalesce(tTR_rC)); // Detect interleaved complex fp32 kernels - Tensor accs = accumulators; + [[maybe_unused]] Tensor accs = accumulators; using ElementTmem = typename decltype(accs)::value_type; constexpr bool is_interleaved_complex_f32 = is_complex::value && cute::is_same_v; @@ -283,28 +324,31 @@ public: copy(tiled_t2r, tTR_tAcc, tTR_rAcc); } + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + // 2. Apply element-wise operation and store to gmem // source is needed if (epilogue_op.is_source_needed()) { + copy_if(tCpC, tTR_rC_src, tTR_rC_dst); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rAcc(i), tTR_gC(i)); - } + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i), tTR_rC(i)); } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } // source is not needed, avoid load else { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tTR_rAcc); ++i) { - if (elem_less(tTR_cD(i), problem_shape_mnl)) { - tTR_gD(i) = epilogue_op(tTR_rAcc(i)); - } + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); } + + copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } - cutlass::arch::fence_view_async_tmem_load(); - acc_pipeline.consumer_release(acc_pipe_consumer_state); - ++acc_pipe_consumer_state; + return cute::make_tuple(acc_pipe_consumer_state); } @@ -397,6 +441,439 @@ protected: Params const& params; }; +template < + class EpilogueTile_, // (EPI_TILE_M, EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpT2R_, + class AlignmentC_, + class AlignmentD_ +> +class CollectiveEpilogue< + Sm100PtrArrayNoSmem, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpT2R_, + AlignmentC_, + AlignmentD_, + cute::enable_if_t::value> +> { +public: + // + // Type Aliases + // + // Required by the gemm::kernel + using DispatchPolicy = Sm100PtrArrayNoSmem; + using ElementC = ElementC_; + using ElementD = ElementD_; + using GmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + using StrideC = StrideC_; + using StrideD = StrideD_; + using InternalStrideC = cute::remove_pointer_t; + using InternalStrideD = cute::remove_pointer_t; + using EpilogueTile = EpilogueTile_; + using CopyOpT2R = CopyOpT2R_; + using FusionCallbacks = FusionCallbacks_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + +private: + constexpr static bool IsReductionBufferNeeded = ThreadEpilogueOp::IsDePerRowBiasSupported + || is_same_v; // alloc reduction buffer for custom EVTs + constexpr static size_t ImplicitSharedStorageSize = IsReductionBufferNeeded ? size(EpilogueTile{}) : 0; + + // Not unroll epi subtile loop when the activation op is heavy to reduce instruction size and register pressure. + constexpr static bool UnrollEpiLoop = + not cutlass::epilogue::thread::kIsHeavy_member_or_false::value; + +public: + constexpr static int ThreadCount = 128; + constexpr static uint32_t TmaTransactionBytes = 0; + + struct SharedStorage { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + array_aligned buffer; + }; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC = {}; + ElementD** ptr_D = nullptr; + StrideD dD = {}; + }; + + // + // Constructor and Data Members + // + CUTLASS_DEVICE + CollectiveEpilogue(Params const& params_, SharedStorage& shared_tensors) + : fusion_callbacks(params_.thread, shared_tensors.thread) + , smem_buffer_ptr(shared_tensors.buffer.data()) + , params(params_) {}; + +protected: + FusionCallbacks fusion_callbacks; + uint8_t* smem_buffer_ptr; + Params const& params; + +public: + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + args.ptr_C, + args.dC, + args.ptr_D, + args.dD + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int /*sm_count*/ = 0) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + return fusion_implementable; + } + + + template< + bool ReuseTmem = false, + class AccumulatorPipeline, + class AccumulatorPipelineState, + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout + > + CUTLASS_DEVICE auto + operator()( + AccumulatorPipeline acc_pipeline, + AccumulatorPipelineState acc_pipe_consumer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + [[maybe_unused]] SharedStorage& + ) { + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + // Wait for mma warp to fill tmem buffer with accumulator results + static_assert(is_tmem::value, "Accumulator must be TMEM resident."); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(rank(CtaCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); + + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = cta_coord_mnkl; + + // Batches are managed by using appropriate pointers to C and D matrices + auto problem_shape_mnl = append<3>(make_shape(M,N),Int<1>{}); + auto cta_coord_mnl = append<3>(make_shape(m_coord, n_coord),Int<0>{}); + auto cta_tiler = take<0,2>(cta_tile_mnk); + auto cta_coord_mnk = cute::make_coord(m_coord, n_coord, k_coord, cute::Int<0>{}); + + bool is_C_load_needed = fusion_callbacks.is_C_load_needed(); + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (is_C_load_needed) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Get the residual tensor for the current batch + ElementC const* ptr_C_l = nullptr; + if (is_C_load_needed) { + ptr_C_l = params.ptr_C[l_coord]; + } + + int thread_idx = threadIdx.x % ThreadCount; + + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + + constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; + + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + + Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); + + // Construct the EVT consumer callbacks + auto residue_cD = make_coord(M,N) - cD(_0{}); + auto residue_tTR_cD = make_coord(M,N) - tTR_cD(_0{}); + Tensor cD_ = make_coord_tensor(cD.layout()); + Tensor tTR_cD_ = make_coord_tensor(tTR_cD.layout()); + constexpr bool RefSrc = false; + + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, stride_c); + + Tensor tTR_gC = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mC, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + Tensor mD = make_tensor(make_gmem_ptr(recast_ptr(params.ptr_D[l_coord])), problem_shape_mnl, stride_d); + + Tensor tTR_gD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + mD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + + // Register Tensor + Tensor tTR_rD = make_tensor(take<0,3>(shape(tTR_gD))); + + Tensor coord_cCD = make_identity_tensor(problem_shape_mnl); + Tensor tTR_cCD = cutlass::epilogue::fusion::sm90_partition_for_epilogue( + coord_cCD, cta_tile_mnk, cta_coord_mnk, EpilogueTile{}, tiled_t2r, thread_idx); + constexpr auto mclD = decltype(max_common_layout(tTR_gD(_,_,_,_0{},_0{}), tTR_rD)){}; + constexpr int VD = cute::min(AlignmentD_{}, size(mclD)); + + auto tCrC = make_tensor(take<0,3>(shape(tTR_gC))); + constexpr auto mclC = decltype(max_common_layout(tTR_gC(_,_,_,_0{},_0{}), tCrC)){}; + constexpr int VC = cute::min(AlignmentC_{}, size(mclC)); + + Tensor tTR_rD_frg = recast>(coalesce(tTR_rD)); + + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + int(0), + EpilogueTile{}, + tiled_t2r, + cD_, + residue_cD, + tTR_cD_, + residue_tTR_cD, + tCrC, + thread_idx + }; + + auto synchronize = [] () CUTLASS_LAMBDA_FUNC_INLINE { cutlass::arch::NamedBarrier::sync(ThreadCount, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // The Epilogue Loop + auto epi_loop_fn = [&] (auto& cst_callbacks) CUTLASS_LAMBDA_FUNC_INLINE { + // Ensure there are no threads from the previous wave writing to shared memory being utilized for the current wave. + synchronize(); + cst_callbacks.begin(); + if (cst_callbacks.begin_sync_needed()) { + synchronize(); + } + + // If tmem doesn't have enough capacity to support double buffering, a portion of tmem (a column of epilogue tiles) + // is overlapped between 2 pseudo-buffers. The shared tmem portion corresponds to the last epilogue tile column of + // tmem accumulator buffer 0, and the first epilogue tile column of tmem accumulator 1. + // Thus, whenever we are processing tmem accumulator buffer 0, we process the epilogue tiles with reversed column order. + // Once the last epilogue tile column is loaded from tmem, the acc_pipeline is released. + // Then, the next accumulation stage for buffer 1 can start. + [[maybe_unused]] bool reverse_epi_n = ReuseTmem && acc_pipe_consumer_state.phase() == 0; + static_assert(not (ReuseTmem && AccumulatorPipeline::Stages != 1), "Tmem reuse requires 1 accumulator stage"); + + // For each epilogue subtile within the CTA tile + constexpr int NumEpiSubtilesN = CUTE_STATIC_V(size<4>(tTR_tAcc)); + constexpr int NumEpiSubtilesM = CUTE_STATIC_V(size<3>(tTR_tAcc)); + + // Lambda to process a single epilogue tile + auto process_tile = [&](int epi_m, int epi_n, int iter_m, int iter_n) CUTLASS_LAMBDA_FUNC_INLINE { + bool is_last_iteration = iter_m == NumEpiSubtilesM-1 && iter_n == NumEpiSubtilesN-1; + bool do_acc_release = is_last_iteration; + + // Adjust release condition for tmem reuse + if constexpr (ReuseTmem) { + do_acc_release = iter_m == NumEpiSubtilesM-1 && iter_n == 0; // Release on first N iteration + } + + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_pCD_mn = cute::lazy::transform(tTR_cCD_mn, [&] (auto const& c) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(c, problem_shape_mnl); }); + cst_callbacks.begin_loop(epi_m, epi_n); + + if constexpr (not cute::is_void_v) { + if (is_C_load_needed) { + using CVecType = uint_bit_t>; + + if constexpr (!is_same_v) { + Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + Tensor tTR_rC_frg = recast(coalesce(tCrC)); + Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); + copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } + } + } + + // Copy accumulator tile from tmem to register + // The current tile in tmem + Tensor tTR_tAcc_mn = tTR_tAcc(_,_,_,epi_m,epi_n); + + Tensor tTR_rAcc_frg = recast>(coalesce(tTR_rAcc)); + + copy(tiled_t2r, tTR_tAcc_mn, tTR_rAcc); + + // After the last tmem load, signal that tmem buffer is consumed and empty + if (do_acc_release) { + cutlass::arch::fence_view_async_tmem_load(); + acc_pipeline.consumer_release(acc_pipe_consumer_state); + ++acc_pipe_consumer_state; + } + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tTR_rAcc_frg); ++epi_v) { + tTR_rD_frg(epi_v) = cst_callbacks.visit(tTR_rAcc_frg(epi_v), epi_v, epi_m, epi_n); + } + + Tensor reduction_buffer = make_tensor( + raw_pointer_cast(make_smem_ptr(smem_buffer_ptr)), make_layout(Shape>{})); + + cst_callbacks.reduce(reduction_buffer, synchronize, epi_m, epi_n, is_last_iteration, tTR_rAcc /*not used*/); + + cst_callbacks.end_loop(epi_m, epi_n); + + using VecType = uint_bit_t>; + if constexpr (!is_same_v) { + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); + copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + }; + + // Use static iteration with appropriate ordering + // When ReuseTmem is true and reverse_epi_n is true, we need reverse N iteration + auto n_seq = cute::make_int_sequence{}; + auto m_seq = cute::make_int_sequence{}; + + if constexpr (UnrollEpiLoop) { + // Fully unrolled static iteration + cute::for_each(n_seq, [&](auto I_N) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_n = I_N; + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + cute::for_each(m_seq, [&](auto I_M) CUTLASS_LAMBDA_FUNC_INLINE { + constexpr int iter_m = I_M; + process_tile(iter_m, epi_n, iter_m, iter_n); + }); + }); + } else { + // Runtime loop with pragma unroll(1) + #pragma unroll 1 + for (int iter_n = 0; iter_n < NumEpiSubtilesN; ++iter_n) { + int epi_n = iter_n; + if constexpr (ReuseTmem) { + if (reverse_epi_n) { + epi_n = NumEpiSubtilesN - 1 - iter_n; // Reverse N iteration + } + } + + #pragma unroll 1 + for (int iter_m = 0; iter_m < NumEpiSubtilesM; ++iter_m) { + process_tile(iter_m, epi_n, iter_m, iter_n); + } + } + } + + cst_callbacks.end(); + }; + + // + // BEGIN EPILOGUE + // + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + epi_loop_fn(cst_callbacks); + return cute::make_tuple(acc_pipe_consumer_state); + } + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // For sm100 kernels requiring warp specialized epilogues @@ -430,7 +907,10 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_>> + CopyOpT2R_, + AlignmentC, + AlignmentD, + void>> { public: // ctor inheritance @@ -442,7 +922,10 @@ public: ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_>>::Sm100TmaWarpSpecializedAdapter; + CopyOpT2R_, + AlignmentC, + AlignmentD, + void>>::Sm100TmaWarpSpecializedAdapter; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp index d8d99849..90dfb80c 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_nosmem.hpp @@ -343,8 +343,7 @@ public: copy_if(tDpD, tTR_rD_src, tR2G_rD_dst); } // source is not needed, avoid load - else - { + else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tTR_rAcc); i++) { tTR_rD_frag(i) = epilogue_op(tTR_rAcc(i)); @@ -587,8 +586,8 @@ public: CtaTileMNK cta_tile_mnk, CtaCoordMNKL cta_coord_mnkl, cute::Tensor accumulators, - [[maybe_unused]]SharedStorage& - ) { + [[maybe_unused]] SharedStorage& + ) { using ElementAccumulator = typename AccEngine::value_type; using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; @@ -718,10 +717,20 @@ public: if (is_C_load_needed) { using CVecType = uint_bit_t>; + if constexpr (!is_same_v) { Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); Tensor tTR_rC_frg = recast(coalesce(tCrC)); Tensor tTR_pC_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclC.compose(Int{}))); copy_if(tTR_pC_frg, tTR_gC_frg, tTR_rC_frg); + } + else { + auto tiled_g2r = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_g2r = tiled_g2r.get_slice(threadIdx.x); + Tensor c_src = thr_g2r.retile_S(tTR_gC(_,_,_,epi_m,epi_n)); + Tensor c_dst = thr_g2r.retile_D(tCrC); + Tensor c_prd = thr_g2r.retile_D(tTR_pCD_mn); + copy_if(tiled_g2r, c_prd, c_src, c_dst); + } } } @@ -753,10 +762,21 @@ public: cst_callbacks.end_loop(epi_m, epi_n); using VecType = uint_bit_t>; + if constexpr (!is_same_v) { Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); Tensor tTR_pD_frg = tensor<1>(zipped_divide(coalesce(tTR_pCD_mn), mclD.compose(Int{}))); copy_if(tTR_pD_frg, tTR_rD_frg, tTR_gD_frg); + } + else { + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_r2g = tiled_r2g.get_slice(threadIdx.x); + Tensor src = thr_r2g.retile_S(tTR_rD); + Tensor dst = thr_r2g.retile_D(tTR_gD(_,_,_,epi_m,epi_n)); + Tensor prd = thr_r2g.retile_D(tTR_pCD_mn); + copy_if(tiled_r2g, prd, src, dst); + } + } // for epi_m } // for epi_n diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 2e6213fe..c9788a42 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -61,8 +61,12 @@ struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueW // Blackwell direct store schedules struct NoSmemWarpSpecialized1Sm {}; struct NoSmemWarpSpecialized2Sm {}; +struct FastF32NoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct FastF32NoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized1Sm : PtrArrayNoSmemWarpSpecialized1Sm {}; +struct PtrArrayFastF32NoSmemWarpSpecialized2Sm : PtrArrayNoSmemWarpSpecialized2Sm {}; // Blackwell TMA schedules struct TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2Sm {}; @@ -234,12 +238,29 @@ struct Sm100PtrArrayTmaWarpSpecialized { static_assert(StagesD >= 1, "StagesD must be >= 1"); }; -// default elementwise operator epilogue without smem -struct Sm100NoSmem {}; -struct Sm100NoSmemWarpSpecialized {}; -struct Sm100PtrArrayNoSmem {}; -struct Sm100PtrArrayNoSmemWarpSpecialized {}; +struct Sm100NoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; +struct Sm100NoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100PtrArrayNoSmem { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; + +struct Sm100PtrArrayNoSmemWarpSpecialized { + constexpr static int StagesC = 1; + constexpr static int StagesD = 1; + constexpr static int FragmentSize = 1; +}; template< int StagesC_, int StagesD_, diff --git a/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp index d81e3b4d..dfbb75bf 100644 --- a/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm100_callbacks_tma_warpspecialized.hpp @@ -1280,6 +1280,39 @@ struct FusionCallbacks< }; +// -------------------------------------------------------------------- +// Sm100PtrArrayNoSmemWarpSpecialized (direct-store, grouped GEMM) +// -------------------------------------------------------------------- +template < + class Operation, + class CtaTile_MNK, + class EpilogueTile_MN, + class... Args +> +struct FusionCallbacks< + epilogue::Sm100PtrArrayNoSmemWarpSpecialized, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> + : FusionCallbacks< + // reuse the ptr-array *TMA* callbacks with 0 stages + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...> { + + using Base = FusionCallbacks< + epilogue::Sm100PtrArrayTmaWarpSpecialized<0,0,0,false,false>, + Operation, + CtaTile_MNK, + EpilogueTile_MN, + Args...>; + + // bring ctors into scope + using Base::Base; +}; } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 72afd1e5..535d8b08 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -871,11 +871,11 @@ struct Sm90ScalarBroadcastPtrArray { template CUTLASS_DEVICE auto get_producer_load_callbacks(ProducerLoadArgs const& args) { - // Get the scalar for batched broadcast - if (size<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } + // Always refresh scalar with the current group index so per-group + // alpha/beta values (provided through pointer arrays) are loaded + // correctly even when the L-stride is zero. + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); return EmptyProducerLoadCallbacks{}; } @@ -904,12 +904,8 @@ struct Sm90ScalarBroadcastPtrArray { > CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - - // Get the scalar for batched broadcast - if (get<2>(params_ptr->dScalar[0]) != 0) { - auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; - update_scalar(l_coord); - } + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); return ConsumerStoreCallbacks(scalar); } @@ -920,13 +916,15 @@ private: int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); if (params_ptr->scalar_ptr_arrays[0] != nullptr) { - scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]); + // Pointer-array variant: each entry already points to the scalar of a group. + scalar = *(params_ptr->scalar_ptr_arrays[0][l_coord]); } else if (params_ptr->scalar_ptrs[0] != nullptr) { + // Strided pointer variant. scalar = params_ptr->scalar_ptrs[0][l_offset]; } else { - // batch stride is ignored for nullptr fallback + // Literal fallback. scalar = params_ptr->scalars[0]; } @@ -936,15 +934,13 @@ private: for (int i = 1; i < BroadcastCount; ++i) { if (params_ptr->scalar_ptr_arrays[i] != nullptr) { - int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); - scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset])); + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][l_coord])); } - if (params_ptr->scalar_ptrs[i] != nullptr) { + else if (params_ptr->scalar_ptrs[i] != nullptr) { int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); } else { - // batch stride is ignored for nullptr fallback scalar = reduction_fn(scalar, params_ptr->scalars[i]); } } diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 49143cf7..d4d286b2 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -38,10 +38,9 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_base.h b/include/cutlass/epilogue/threadblock/epilogue_base.h index 57ba7aab..26c8ba82 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_base.h +++ b/include/cutlass/epilogue/threadblock/epilogue_base.h @@ -37,15 +37,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if !defined(__CUDACC_RTC__) #include #include #endif +#include CUDA_STD_HEADER(cassert) -#include - -#include "cutlass/cutlass.h" #include "cutlass/matrix_shape.h" #include "cutlass/numeric_types.h" #include "cutlass/array.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h index e8d6fbcc..17a45387 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_gemm_k_reduction.h @@ -37,10 +37,8 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h index 73213557..4569ee8b 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h +++ b/include/cutlass/epilogue/threadblock/epilogue_smem_accumulator.h @@ -37,10 +37,8 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/numeric_types.h" #include "cutlass/array.h" #include "cutlass/layout/vector.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h index 6a50a500..17cfbcf4 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_streamk_with_broadcast.h @@ -38,16 +38,16 @@ */ #pragma once +#include "cutlass/cutlass.h" -#include +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h index 751ce50f..85735240 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_absmax.h @@ -49,16 +49,15 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h index 312d43c9..e9cf5e18 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h @@ -38,16 +38,15 @@ */ #pragma once - -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(utility) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h index 5699a23e..81f5567f 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_reduction.h @@ -38,10 +38,10 @@ */ #pragma once - -#include - #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) + + #include "cutlass/array.h" #include "cutlass/numeric_types.h" #include "cutlass/numeric_conversion.h" diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h b/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h new file mode 100644 index 00000000..da363739 --- /dev/null +++ b/include/cutlass/epilogue/threadblock/epilogue_with_scaling_factor.h @@ -0,0 +1,231 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue visitor for threadblock scoped GEMMs that process softmax computations in epilogue. + + The epilogue finds max values in each row of the row-major output matrix and stores them. + The max values are also used for a further round of threadblock scoped reduction operation, where + the partial reduction results are stored in a pre-allocated array and used for further full reduction. + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" // cutlass::TensorRef + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class GemvEpilogueWithScalingFactor +{ + public: + using ThreadShape = ThreadShape_; + using ElementCompute = ElementCompute_; // f32 + using ElementAccumulator = ElementAccumulator_; // f32 + using ElementC = ElementC_; // e2m1 + using ElementD = ElementD_; // e2m1 + using ElementSFD = ElementSFD_; // e4m3 + using LayoutOutput = LayoutOutput_; // ColumnMajor + using LayoutSFD = LayoutSFD_; // ColumnMajor + using TensorRefD = TensorRef; + static constexpr int kVectorSize = kVectorSize_; + // number of threads row + static constexpr int kThreadsPerCol = ThreadShape::kM; // 16 + // number of threads col + static constexpr int kThreadsPerRow = ThreadShape::kN; // 8 + static constexpr int kThreadCount = kThreadsPerCol * kThreadsPerRow; // 128 + + static_assert(kVectorSize == kThreadsPerCol, "vector size and number of threads row should be equal"); + static_assert(std::is_same_v && + std::is_same_v, + "Only support Mx1 (ColumnMajor) output and ColumnMajor scaling factor"); + static_assert(std::is_same_v, "ElementCompute should be float type"); + static_assert(cutlass::sizeof_bits::value == 4, "Output should be FP4 type"); + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFD should be FP8 type"); + static_assert(std::is_same_v, "only support same layout for D and SFD"); + + // Hardcode static_assert on threadshape 16x8 to avoid bug + static_assert(kThreadsPerCol == 16, "thread shape col false"); + static_assert(kThreadsPerRow == 8, "thread shape row false"); + static_assert(kThreadCount == 128, "thread count false"); + + struct Params + { + TensorRefD tensor_d; + ElementSFD *scale_factor_d_ptr{nullptr}; + ElementCompute alpha{0}; + ElementCompute beta{0}; + float st{0}; + int64_t batch_stride_sfd{0}; // Add batch stride for SFD + int64_t stride_d{0}; // Add stride for D tensor + }; + + /// Shared storage + struct SharedStorage + { + // fp32 + // Each thread store one fp32 +#if 1 + ElementAccumulator reduction_buffer[kThreadsPerCol]; +#else + ElementAccumulator reduction_buffer[kThreadCount]; +#endif + // Buffer for collecting 4-bit values for packed store + uint8_t packed_buffer[kThreadsPerCol]; + }; + + private: + Params const ¶ms_; + SharedStorage &shared_storage_; + float st_scale_down{0}; + + public: + CUTLASS_HOST_DEVICE GemvEpilogueWithScalingFactor(Params const ¶ms, SharedStorage &shared_storage) + : params_(params) + , shared_storage_(shared_storage) + { + const float fp_subtype_max = static_cast(cutlass::platform::numeric_limits::max()); + this->st_scale_down = this->params_.st / fp_subtype_max; + } + + CUTLASS_DEVICE void operator()(ElementAccumulator frag_acc, ElementC frag_c, int batch_idx) + { + const int block_idx = blockIdx.x; + const int thread_idx_col = threadIdx.x; + const int thread_idx_row = threadIdx.y; + + const float st_scale_down = this->st_scale_down; + const float st = this->params_.st; + + // Compute D offset using batch_idx and stride_d + const int output_d_base_offset = blockIdx.x * blockDim.y; + const int d_batch_offset = batch_idx * params_.stride_d; + ElementD* output_ptr = ¶ms_.tensor_d.at({output_d_base_offset + d_batch_offset, 0}); + uint8_t* byte_ptr = reinterpret_cast(output_ptr); + // For 8x16 thread layout, 1 thread per 128 threads write to sf d + // Every block write one SFD to gmem + const bool is_write_sfd_thread = (thread_idx_row == 0); + + // Calculate SFD offset using proper batch stride + const int output_sfd_offset = (block_idx / 4) * 512 + block_idx % 4 + batch_idx * params_.batch_stride_sfd; + + auto reduction_buffer = shared_storage_.reduction_buffer; + // fp32 + ElementAccumulator max_accum_row0 = ElementAccumulator(0); + ElementAccumulator max_accum_row1 = ElementAccumulator(0); + + // Thread in row contain duplicate frag_acc data + if ( thread_idx_col == 0 ) { + // 16 threads write to 16 contigious bank, no conflict + reduction_buffer[thread_idx_row] = frag_acc; + } + + __syncthreads(); + + if (threadIdx.y == 0) { + auto acc_0 = reduction_buffer[threadIdx.x * 2]; + auto acc_1 = reduction_buffer[threadIdx.x * 2 + 1]; + // compute the max for me using shuffling among 16 threads. + ElementAccumulator max_accum = fabsf(acc_0); + max_accum = cutlass::fast_max(max_accum, fabsf(acc_1)); + + // Butterfly reduction pattern for 16 threads + // Each iteration halves the number of active lanes + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 4)); // 8->4 + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 2)); // 4->2 + max_accum = cutlass::fast_max(max_accum, __shfl_down_sync(0xFF, max_accum, 1)); // 2->1 + + // Broadcast the final result to all 8 threads + max_accum = __shfl_sync(0xFF, max_accum, 0); + + float pvscale = max_accum * st_scale_down; + ElementSFD qpvscale = static_cast(pvscale); + float qpvscale_up = NumericConverter{}(qpvscale); + float qpvscale_up_rcp = __frcp_rn(qpvscale_up) * st; + uint8_t qval_u8_compare; + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint32_t temp_result; + asm volatile ( + "{\n" + " .reg .f32 output_fp32_0, output_fp32_1;\n" + " .reg .b8 byte0, byte1, byte2, byte3;\n" + " mul.f32 output_fp32_0, %1, %3;\n" + " mul.f32 output_fp32_1, %2, %3;\n" + " cvt.rn.satfinite.e2m1x2.f32 byte0, output_fp32_1, output_fp32_0;\n" + " mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}\n" + : "=r"(temp_result) // Output to uint32_t + : "f"(acc_0), "f"(acc_1), "f"(qpvscale_up_rcp) + ); + qval_u8_compare = temp_result & 0xFF; + #else + ElementD output_fp4_0 = NumericConverter{}(acc_0 * qpvscale_up_rcp); + ElementD output_fp4_1 = NumericConverter{}(acc_1 * qpvscale_up_rcp); + uint8_t raw_fp4_0 = reinterpret_cast(output_fp4_0) & 0x0F; + uint8_t raw_fp4_1 = reinterpret_cast(output_fp4_1) & 0x0F; + qval_u8_compare = (raw_fp4_1 << 4) | raw_fp4_0; + #endif + byte_ptr[threadIdx.x] = qval_u8_compare; + + arch::global_store(qpvscale, + (void *)(params_.scale_factor_d_ptr + output_sfd_offset), + is_write_sfd_thread); + + } + + } // end of operator() +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 4a758b7f..eb14856f 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -30,18 +30,17 @@ **************************************************************************************************/ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #include #include #endif #if !defined(__QNX__) -#include +#include CUDA_STD_HEADER(utility) #endif -#include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" #include "cutlass/coord.h" diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 574202ee..eab0b35f 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -59,12 +59,14 @@ #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUDA_PTX_UE8M0_CVT_ENABLED 1 #endif #if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUDA_PTX_UE8M0_CVT_ENABLED 1 #endif diff --git a/include/cutlass/float_subbyte.h b/include/cutlass/float_subbyte.h index 547714b7..eefab027 100644 --- a/include/cutlass/float_subbyte.h +++ b/include/cutlass/float_subbyte.h @@ -45,12 +45,14 @@ #endif #if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) # define CUDA_PTX_FP4FP6_CVT_ENABLED 1 #endif #if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ - defined(CUTLASS_ARCH_MMA_SM120F_ENABLED)) + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) # define CUDA_PTX_FP4FP6_CVT_ENABLED 1 #endif diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index d705777f..636cb8ca 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -54,7 +54,8 @@ #include #endif // _MSC_VER -#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) # define CUTLASS_ARCH_CREDUX_ENABLED #endif diff --git a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl index 8617e883..0566905d 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -132,7 +132,7 @@ auto sm100_make_simt_gmem_tiled_copy_SFA() { return make_tiled_copy( SmemScalingCopyAtomA{}, Layout>{}, // 32 threads - Layout, Int>>, Stride>>{}); + Layout>>{}); } else { using SmemScalingCopyAtomA = Copy_Atom, Element>; @@ -166,7 +166,7 @@ auto sm100_make_simt_gmem_tiled_copy_SFB() { return make_tiled_copy( SmemScalingCopyAtomB{}, Layout>{}, // 32 threads - Layout, Int>>, Stride>>{}); + Layout>>{}); } else { using SmemScalingCopyAtomB = Copy_Atom, Element>; diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 385fea22..2a5922e2 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -569,6 +569,73 @@ sm100_make_trivial_fastFP32_tiled_mma() { } } +//Setting mma for Mixed input gemm. Here, ElementAMma should be TACompute +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class KernelScheduleType +> +constexpr auto +sm100_make_trivial_mixed_input_tiled_mma() { + constexpr int M = cute::size<0>(TileShape_MNK{}); + constexpr int N = cute::size<1>(TileShape_MNK{}); + //MMA 1Sm requested + if constexpr (cute::is_base_of_v ) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_TS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F8F6F4_TS{}); + } + } + else { // If A needs to be transposed by MMA, fall back to SMEM from A MMA instructions + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_SS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma( + cute::MMA_Traits< + cute::SM100_MMA_F8F6F4_SS, + ElementAMma, + ElementBMma, + ElementAccumulator, + cute::C, + cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>{}); + } + } + } + //MMA 2Sm requested + else if constexpr (cute::is_base_of_v) { + if constexpr (UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v) { + if constexpr (cute::is_same_v || cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F16BF16_2x1SM_TS{}); + } + if constexpr (cute::is_same_v) { + return make_tiled_mma(cute::SM100_MMA_F8F6F4_2x1SM_TS{}); + } + } + } + else { + static_assert(cutlass::detail::dependent_false == 0, + "Unsupported policy for SM100 collective builder."); + } +} + template< class CtaShape_MNK > diff --git a/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl new file mode 100644 index 00000000..6a30b41b --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + BuilderScheduleTag, + cute::enable_if_t || + (cute::is_same_v && + (((sizeof(ElementA) * AlignmentA) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0) || + ((sizeof(ElementB) * AlignmentB) % cutlass::gemm::collective::detail::tma_alignment_bytes != 0)))> +> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, BuilderScheduleTag>()); + + using AtomThrID = typename TiledMma::AtomThrID; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + // Assigning 4 warps for mainloop load + static constexpr int NumLoadThreads = 128; + + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * AlignmentA>; + using GmemCopyAtomA = cute::Copy_Atom, ElementA>; + using GmemTiledCopyA = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomA, NumLoadThreads, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, BlockTileA_M, BlockTileA_K>()); + + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * AlignmentB>; + using GmemCopyAtomB = cute::Copy_Atom, ElementB>; + using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< + GmemCopyAtomB, NumLoadThreads, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, BlockTileB_N, BlockTileB_K>()); + + static constexpr uint32_t AccumulatorPipelineStageCount = 2; + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + CLCResponseStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm100UmmaCpAsyncWarpSpecialized< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl new file mode 100644 index 00000000..eed95105 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementAMma, + class ElementScale, + class ElementZero, + class ElementB, + class CtaTileShape_MNK, + class TiledMma, + class KernelScheduleType, + UMMA::Major UmmaMajorA, + int ScaleGranularityK, + int stages +> +constexpr cute::tuple +sm100_compute_stage_count_or_override_mixed_input(StageCount stage_count) { + constexpr int Load2TransformStageCount = stages; + constexpr int Transform2MmaStageCount = stages; + constexpr int AccumulatorStageCount = stages; + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +template< + int CapacityBytes, + class ElementA, + class ElementAMma, + class ElementScale, + class ElementZero, + class ElementB, + class CtaTileShape_MNK, + class TiledMma, + class KernelScheduleType, + UMMA::Major UmmaMajorA, + int ScaleGranularityK, + int carveout_bytes +> +constexpr cute::tuple +sm100_compute_stage_count_or_override_mixed_input(StageCountAutoCarveout stage_count) { + + constexpr int CtaM = get<0>(CtaTileShape_MNK{}); + constexpr int CtaN = get<1>(CtaTileShape_MNK{}); + static_assert(CtaN <= 128, "Can't support CtaN>128 tiles"); + constexpr int CtaK = get<2>(CtaTileShape_MNK{}); + using AtomThrID = typename TiledMma::AtomThrID; + + constexpr int TmemColumns = 512; + + constexpr bool IsAComputeinTmem = UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v; + constexpr bool IsAComputeinSmem = !IsAComputeinTmem; + + // Detect 2x2 TMEM layout + constexpr int TmemAccWordsPerDP = (CtaM == 64 && size(AtomThrID{}) == 2) ? CtaN/2 : CtaN; + constexpr int TmemAWordsPerDP = CtaK / 2; + + constexpr int AccumulatorStageCount = (IsAComputeinTmem) ? ((TmemAccWordsPerDP == 128) ? 2 : 3) : (TmemColumns / TmemAccWordsPerDP); + + constexpr int SmemCapacityAfterMma2AccumCarveout = CapacityBytes - (carveout_bytes + AccumulatorStageCount * 32); + + constexpr int TmemInAStageCount_Potential = (IsAComputeinTmem) ? (TmemColumns - AccumulatorStageCount * TmemAccWordsPerDP) / TmemAWordsPerDP : 10000; + + // Mainload2Transform Pipeline + constexpr auto load2transform_pipeline_bytes = sizeof(typename cutlass::PipelineTmaTransformAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; // ElementA introduce here + constexpr auto s_bits = cute::is_void_v ? 0 : cute::sizeof_bits_v; + constexpr auto z_bits = cute::is_void_v ? 0 : cute::sizeof_bits_v; + + constexpr auto load2mma_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto b_bits = cute::sizeof_bits_v; // ElementB introduce here + + constexpr int ab_stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + + cutlass::bits_to_bytes(s_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(z_bits * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{}) / ScaleGranularityK) + + cutlass::bits_to_bytes(b_bits * size<1>(CtaTileShape_MNK{}) / size(AtomThrID{}) * size<2>(CtaTileShape_MNK{})) + + static_cast(load2transform_pipeline_bytes) + static_cast(load2mma_pipeline_bytes); + + // Transform2Mma Pipeline + constexpr auto transform2mma_pipeline_bytes = sizeof(typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage); + constexpr auto a_compute_bits = cute::sizeof_bits_v; + constexpr int ab_compute_stage_bytes = + cutlass::bits_to_bytes(a_compute_bits * int(IsAComputeinSmem) * size<0>(CtaTileShape_MNK{}) * size<2>(CtaTileShape_MNK{})) + // If ACompute is in TMEM, Acompute buffer has 0 bytes. + static_cast(transform2mma_pipeline_bytes); + + constexpr int ABComputeStageCount_Potential = SmemCapacityAfterMma2AccumCarveout / (ab_stage_bytes + ab_compute_stage_bytes); + + // The number of SMEM buffers for A, B. ACompute (if in SMEM), BCompute should be at least Transform2MmaStageCount + constexpr int Transform2MmaStageCount = std::min(TmemInAStageCount_Potential, ABComputeStageCount_Potential); + + constexpr int SmemCapacityAfterABComputeCarveout = SmemCapacityAfterMma2AccumCarveout - (Transform2MmaStageCount * ab_compute_stage_bytes); + + // Can we boost the number of buffers for A and B? + constexpr int Load2TransformStageCount = SmemCapacityAfterABComputeCarveout / ab_stage_bytes; + + static_assert(Load2TransformStageCount >= 2 && Transform2MmaStageCount >= 2 && AccumulatorStageCount >= 2, "Not enough SMEM or TMEM capacity for selected tile size"); + return cute::make_tuple(Load2TransformStageCount, Transform2MmaStageCount, AccumulatorStageCount); +} + +} // namespace detail + +template +constexpr int get_ScaleGranularityK() { + if constexpr (cute::is_void_v) { + return 1; + } else { + return size<1,0>(LayoutScale{}); + } +} + + +// Mixed Input MMA kernels builder +template < + class ElementAOptionalTuple, + class GmemLayoutATagTuple, + int AlignmentA, + class ElementBOptionalTuple, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // The Cluster-level TileShape + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementAOptionalTuple, // ElementA + GmemLayoutATagTuple, // LayoutA + AlignmentA, + ElementBOptionalTuple, // ElementB + GmemLayoutBTag, // LayoutB + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, int) + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_base_of_v) && + ((sizeof(float) * AlignmentA) % detail::tma_alignment_bytes == 0) && + ((sizeof(float) * AlignmentB) % detail::tma_alignment_bytes == 0)>> +{ + using GmemLayoutATag = detail::deduce_mixed_width_dtype_t<0, GmemLayoutATagTuple>; + using GmemLayoutScaleTag = detail::deduce_mixed_width_dtype_t<1, GmemLayoutATagTuple>; + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + using ElementScale = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ElementZero = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + + static constexpr bool NeitherIsTuple = !cute::is_tuple::value && !cute::is_tuple::value; + static constexpr bool IsANarrow = cute::sizeof_bits_v < cute::sizeof_bits_v; + static constexpr bool IsMixedInput = cute::sizeof_bits_v != cute::sizeof_bits_v; + static_assert(IsMixedInput, "Mixed Input GEMM Kernel doesn't support regular gemm."); + + static_assert((cute::is_tuple::value ^ cute::is_tuple::value || + (NeitherIsTuple && (cute::sizeof_bits::value != cute::sizeof_bits::value))), + "Either A OR B must be a tuple or the widths of A and B must be different."); + using ElementPairA = cute::conditional_t, ElementAOptionalTuple>; + using ElementPairB = cute::conditional_t, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + static_assert(IsATransformed, "A matrix should be transformed."); + + // For fp32 types, map to tf32 MMA value type. + using ElementMma = cute::conditional_t, tfloat32_t, ElementB>; + + + using ElementAMma = ElementMma; + using ElementBMma = ElementMma; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + static constexpr int ScalingFactor = 1; + + using TiledMma = decltype(detail::sm100_make_trivial_mixed_input_tiled_mma()); + using AtomThrID = typename TiledMma::AtomThrID; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(cute::size<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm100_cluster_shape_to_tma_atom_B(ClusterShape_MNK{}, AtomThrID{})); + + // Input transform kernel can not use TMA 2SM instructions. + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomACompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomA, SmemLayoutAtomACompute>; + static constexpr int MMA_M = cute::size<0,0>(MmaShapeA_MK{}); + using CopyAtomPairA = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementA>, + cute::conditional_t<(UmmaMajorA == cute::UMMA::Major::K && !cute::is_base_of_v), + cute::conditional_t<(MMA_M == 64 && size(AtomThrID{}) == 1), SM100_TMEM_STORE_16dp256b1x, SM100_TMEM_STORE_32dp32b8x>, // TS Implementation + Copy_Atom, ElementA>> // SS Implementation + >; + + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + // Input transform kernel can not use TMA 2SM instructions. + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomBCompute = decltype(cutlass::gemm::collective::detail::sm100_smem_selector()); + using SmemLayoutAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedLayoutAtomType< + SmemLayoutAtomB, SmemLayoutAtomBCompute>; + using CopyAtomPairB = cutlass::gemm::collective::detail::CollectiveMmaEmulatedCopyType< + Copy_Atom, ElementB>, + Copy_Atom, ElementMma> + >; + + //Creating the stride of Transformed Input + using StrideA = cutlass::gemm::TagToStrideA_t; + using LayoutScale = cutlass::gemm::TagToStrideA_t; + + using VoidShapeScale = Shape, _1>, Shape, _1>, _1>; //Dummy Value to create a dummy ScaleConfig + using VoidStrideScale = Stride,Stride<_0, _1>, _1>; + using VoidLayoutScale = Layout; + + using NonVoidLayoutScale = cute::conditional_t< + cute::is_void_v, VoidLayoutScale, LayoutScale>; + + using StridePairA = decltype(cute::make_tuple(StrideA{}, NonVoidLayoutScale{})); + + // SmemCarveout + static constexpr int SchedulerPipelineStageCount = 3; + static constexpr bool IsArrayOfPointersGemm = (cute::is_base_of_v); + + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = sizeof(uint32_t); + // Tensormap Storage + static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0; + + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( CLCPipelineStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage); + + // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations + static constexpr int Sm100ReducedSmemCapacityBytes = detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + static constexpr int ScaleGranularityK = get_ScaleGranularityK(); + + static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_mixed_input< + Sm100ReducedSmemCapacityBytes, TmaElementA, ElementAMma, ElementScale, ElementZero, ElementB, CtaTileShape_MNK, TiledMma, KernelScheduleType, UmmaMajorA, ScaleGranularityK>(StageCountType{}); + + static constexpr int Load2TransformPipelineStageCount = get<0>(stage_info); + static constexpr int Transform2MmaPipelineStageCount = get<1>(stage_info); + static constexpr int AccumulatorPipelineStageCount = get<2>(stage_info); + + static_assert(!IsArrayOfPointersGemm, "mixed input does not support grouped gemm on Blackwell"); + + using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount, + Transform2MmaPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK + >; + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StridePairA, + ElementPairB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomPairA, + CopyAtomPairA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomPairB, + CopyAtomPairB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index e7f5235a..5edf637e 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -184,6 +184,7 @@ struct CollectiveBuilder< not cute::is_complex_v && not cute::is_complex_v && // Dense Gemm / PtrArrayDenseGemm ( + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v)) && // Alignment check diff --git a/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl new file mode 100644 index 00000000..83f106c2 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl @@ -0,0 +1,546 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int stages +> +constexpr int +sm103_compute_stage_count_or_override_blockscaled(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template < + int CapacityBytes, + class ElementA, + class ElementB, + class TileShapeMNK, + class TileShapeSFA, + class TileShapeSFB, + int carveout_bytes +> +constexpr auto +sm103_compute_stage_count_or_override_blockscaled(StageCountAutoCarveout stage_count) { + // For F8F6F4 MMA sub-bytes, ElementA/B will be passed in as uint8_t + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. one MainloopPipeline = PipelineTmaUmmaAsync (CollectiveMma::SharedStorage::SharedStorage) + // 3. smem for SFB and smem for SFB (CollectiveMma::SharedStorage::TensorStorage, independent of input size b.c. sizeof(sf) is fixed) + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto stage_sfa_bytes = size(filter_zeros(TileShapeSFA{})); + constexpr auto stage_sfb_bytes = size(filter_zeros(TileShapeSFB{})); + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes * 2 + stage_sfa_bytes + stage_sfb_bytes); + + constexpr int ab_buffer = (CapacityBytes - carveout_bytes) / stage_bytes; + constexpr int sb_buffer = ab_buffer + (CapacityBytes - carveout_bytes - ab_buffer * stage_bytes) / (mainloop_pipeline_bytes + stage_sfa_bytes + stage_sfb_bytes); + return make_tuple(ab_buffer, sb_buffer); +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + int SFVectorSize +> +constexpr auto +sm103_make_blockscaled_1sm_tiled_mma() { + using AtomLayout_MNK = Layout; + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 64 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM103::SM103_MXF4_ULTRA_SS_VS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM103 collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + int SFVectorSize +> +constexpr auto +sm103_make_blockscaled_2sm_tiled_mma() { + using AtomLayout_MNK = Layout{}))>; + + constexpr int M = cute::size<0>(TileShape_MNK{}); + static_assert(M == 128 || M == 256, "Invalid TileShape_M."); + + constexpr int N = cute::size<1>(TileShape_MNK{}); + static_assert(N % 64 == 0 && N <= 256, "Invalid TileShape_N."); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_tiled_mma(cute::SM103::SM103_MXF4_ULTRA_2x1SM_SS_VS{}); + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for SM103 collective builder."); + } +} + + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class ClusterTileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + class BuilderScheduleTag +> +constexpr auto +sm103_make_blockscaled_tiled_mma() { + constexpr uint32_t SFVectorSize = find_vector_size(); + + // MMA_2SM requested + if constexpr (cute::is_base_of_v) { + return sm103_make_blockscaled_2sm_tiled_mma(); + } + // MMA_1SM requested + else if constexpr (cute::is_base_of_v) { + return sm103_make_blockscaled_1sm_tiled_mma(); + } + // Auto scheduling requested + else if constexpr (cute::is_same_v) { + if constexpr (cute::get<0>(ClusterShape_MNK{}) % 2 == 0) { + return sm103_make_blockscaled_2sm_tiled_mma(); + } + else { + return sm103_make_blockscaled_1sm_tiled_mma(); + } + } + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported policy for SM103 collective builder."); + } +} + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag, + bool Is2SM +> +struct Sm103TrivialBlockscaledMma {}; + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag +> +struct Sm103TrivialBlockscaledMma< ElementAMma, + ElementBMma, + ElementAccumulator, + ElementSF, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + SFVectorSize, + BuilderScheduleTag, + true /*Is2SM*/> { + using type = decltype(sm103_make_blockscaled_2sm_tiled_mma()); + }; + +template< + class ElementAMma, + class ElementBMma, + class ElementAccumulator, + class ElementSF, + class TileShape_MNK, + class ClusterShape_MNK, + UMMA::Major UmmaMajorA, + UMMA::Major UmmaMajorB, + uint32_t SFVectorSize, + class BuilderScheduleTag +> +struct Sm103TrivialBlockscaledMma< ElementAMma, + ElementBMma, + ElementAccumulator, + ElementSF, + TileShape_MNK, + ClusterShape_MNK, + UmmaMajorA, + UmmaMajorB, + SFVectorSize, + BuilderScheduleTag, + false /*Is2SM*/> { + using type = decltype(sm103_make_blockscaled_1sm_tiled_mma()); +}; + +template +CUTLASS_HOST_DEVICE +static constexpr bool +is_sm103_block_scale_input() { + // Allowed input element datatype for block-scaling GEMM + return ( cute::is_same_v || + cute::is_same_v); +} + +template +constexpr +auto sm103_sfa_smem_atom_layout() { + constexpr int SF_BUFFERS_PER_TILE_K = BlockScaleConfig::SFVecSize == 16 ? 4 : 2; + auto mma_sfa_tiler = make_shape(get<0,0>(MmaShapeA_MK{})*get<1>(MmaShapeA_MK{}), get<0,1>(MmaShapeA_MK{}) * get<2>(MmaShapeA_MK{}) / Int{}); + return tiled_product(typename BlockScaleConfig::SfAtom{}, + make_layout(shape_div(mma_sfa_tiler, product_each(shape(typename BlockScaleConfig::SfAtom{}))))); +} + +template +constexpr +auto sm103_sfb_smem_atom_layout() { +auto sSFB = [&]() { + constexpr int MMA_N = get<0>(MmaShapeB_NK{}); + constexpr int NonPow2N = 192; + constexpr int NonPow2N_RoundUp = 256; + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + constexpr int MMA_N_SFB = (MMA_N == NonPow2N) ? NonPow2N_RoundUp : MMA_N; + constexpr int SF_BUFFERS_PER_TILE_K = BlockScaleConfig::SFVecSize == 16 ? 4 : 2; + auto mma_sfb_tiler = make_shape(Int{}, get<1>(MmaShapeB_NK{}) / Int{}); + if constexpr(Int{} == Int<128>{}) { + return tiled_product(typename BlockScaleConfig::SfAtom{}, + make_layout(shape_div(mma_sfb_tiler,product_each(shape(typename BlockScaleConfig::SfAtom{}))))); + + } + else { + using SfKMajorAtom256 = Layout< Shape< Shape<_32,_4, _2>, Shape, _4>>, + Stride(mma_sfb_tiler)/SFVecSize/4*512>>, Stride< _0, _1>>>; + return tiled_product(SfKMajorAtom256{}, + make_layout(shape_div(mma_sfb_tiler,product_each(shape(SfKMajorAtom256{}))))); + } + }(); + return sSFB; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementPairA, + class GmemLayoutATag, + int AlignmentA, + class ElementPairB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class StageCountType, + class BuilderScheduleTag +> +struct CollectiveBuilder< + arch::Sm103, + arch::OpClassBlockScaledTensorOp, + ElementPairA, + GmemLayoutATag, + AlignmentA, + ElementPairB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + BuilderScheduleTag, + cute::enable_if_t< + // Not paired input, Not Complex input + (cute::is_tuple_v && cute::is_tuple_v && + not cute::is_complex_v && not cute::is_complex_v) && + // Blockscaled Gemm + (cute::is_base_of_v || + cute::is_base_of_v || + cute::is_same_v) && + // Alignment check + detail::sm1xx_blockscaled_gemm_is_aligned(ElementPairA{}))>, + AlignmentA, + remove_cvref_t(ElementPairB{}))>, + AlignmentB, + BuilderScheduleTag>()>> +{ + using ElementA = remove_cvref_t(ElementPairA{}))>; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using ElementSF = remove_cvref_t(ElementPairA{}))>; + + static_assert(cute::is_tuple::value, "Expecting ElementPairA to be a tuple."); + static_assert(cute::is_tuple::value, "Expecting ElementPairB to be a tuple."); + + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(cute::size<2>(TileShape_MNK{}) == _768{}, "TileShape_K should 768 for MMA kernels"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + static_assert(cutlass::gemm::detail::is_k_major_A() && cutlass::gemm::detail::is_k_major_B(), "Only K major inputs are supported"); + + static_assert(cutlass::gemm::collective::detail::is_sm103_block_scale_input(), "Incorrect type for A matrix"); + static_assert(cutlass::gemm::collective::detail::is_sm103_block_scale_input(), "Incorrect type for B matrix"); + + static_assert(cute::is_same_v || + cute::is_same_v, "Incorrect scale factor type"); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + static constexpr uint32_t SFVectorSize = detail::find_vector_size(); + + static constexpr bool is_2sm = cute::is_base_of_v || + (cute::is_same_v && + (cute::is_static_v && cute::get<0>(ClusterShape_MNK{}) % 2 == 0)); + + using TiledMma = typename cutlass::gemm::collective::detail::Sm103TrivialBlockscaledMma::type; + + using AtomThrID = typename TiledMma::AtomThrID; + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + + using ElementAMma_SmemAllocType = uint8_t; + // ElementAMma; + using ElementBMma_SmemAllocType = uint8_t; + // ElementBMma; + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopySFB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_SFB( + ClusterShape_MNK{}, AtomThrID{})); + + using GmemTiledCopyPairA = decltype(cute::make_tuple(GmemTiledCopyA{}, GmemTiledCopySFA{})); + using GmemTiledCopyPairB = decltype(cute::make_tuple(GmemTiledCopyB{}, GmemTiledCopySFB{})); + + // + // Construct SMEM layout (SmemLayoutAtom) for A and SFA + // + using SmemLayoutAtomA = UMMA::Layout_K_SW128_Atom; + // A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size + static constexpr int MMA_M = cute::size<0>(TileShape_MNK{}) / cute::size(AtomThrID{}); + using SmemLayoutAtomSFA = decltype(detail::sm103_sfa_smem_atom_layout()); + using SmemLayoutAtomsA = decltype(cute::make_tuple(SmemLayoutAtomA{}, SmemLayoutAtomSFA{})); + + // + // Construct SMEM layout(SmemLayoutAtom)for B and SFB + // + + using SmemLayoutAtomB = UMMA::Layout_K_SW128_Atom; + static constexpr int MMA_N = cute::size<1>(TileShape_MNK{}); + // If MMA_N is 192, we need to operate at MMA_N = 256 granularity for UTCCP to work for ScaleFactorB. + // Both TMA and UTCCP will transfer scale factor B as if we have 256 columns in B matrix. + using SmemLayoutAtomSFB = decltype(detail::sm103_sfb_smem_atom_layout(TileShape_MNK{})),SFVectorSize>()); + using SmemLayoutAtomsB = decltype(cute::make_tuple(SmemLayoutAtomB{}, SmemLayoutAtomSFB{})); + + // + // Construct Strides for A, SFA, B, and SFB + // + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using InternalStrideA = cute::remove_pointer_t; + using InternalStrideB = cute::remove_pointer_t; + using InternalLayoutSFA = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFA()); + using InternalLayoutSFB = decltype(Sm1xxBlkScaledConfig::deduce_layoutSFB()); + using LayoutSFA = cute::conditional_t, InternalLayoutSFA, InternalLayoutSFA *>; + using LayoutSFB = cute::conditional_t, InternalLayoutSFB, InternalLayoutSFB *>; + using StridePairA = decltype(cute::make_tuple(StrideA{}, LayoutSFA{})); + using StridePairB = decltype(cute::make_tuple(StrideB{}, LayoutSFB{})); + + // + // Others + // + + static constexpr cutlass::sm103::detail::KernelPrefetchType PrefetchType = cute::is_base_of_v + || cute::is_base_of_v + ? cutlass::sm103::detail::KernelPrefetchType::Disable : + cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch; + + static constexpr uint32_t AccumulatorPipelineStageCount = (MMA_N == 256) ? 1 : 2; + static constexpr uint32_t SchedulerPipelineStageCount = 3; + + // AccumulatorPipeline = PipelineUmmaAsync + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + // CLCPipeline = PipelineCLCFetchAsync + static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); + // LoadOrderBarrier = OrderedSequenceBarrier<1,2> + static constexpr auto LoadOrderBarrierStorage = sizeof(typename cutlass::OrderedSequenceBarrier<1,2>::SharedStorage); + // CLC (scheduler) response + static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // CLC Throttle pipeline storage + static constexpr auto CLCThrottlePipelineStorage = sizeof(typename cutlass::PipelineAsync::SharedStorage); + // Tmem dealloc + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem ptr storage + static constexpr auto TmemBasePtrsStorage = AccumulatorPipelineStageCount * sizeof(uint32_t); + // Tensormap Storage + static constexpr bool IsArrayOfPointersGemm = cute::is_base_of_v; + static constexpr auto TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 4 /* for A, B, SFA and SFB */ : 0; + // TMA Load Prefetch Storage + static constexpr auto TmaPrefetchStorage = 0; + // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage + static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + + CLCPipelineStorage + + LoadOrderBarrierStorage + + CLCResponseStorage + + CLCThrottlePipelineStorage + + TmemDeallocStorage + + TmemBasePtrsStorage + + TensorMapStorage + + TmaPrefetchStorage); + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape, Int, _128>; // SmemAllocTypes are uint8_t. We always allocate 128bytes + static constexpr auto PipelineStages = cutlass::gemm::collective::detail::sm103_compute_stage_count_or_override_blockscaled< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + + using DispatchPolicy = typename cute::conditional_t(PipelineStages), + get<1>(PipelineStages), + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK, + PrefetchType + >, + cutlass::gemm::MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + get<0>(PipelineStages), + get<1>(PipelineStages), + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK, + PrefetchType + > + >; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA, + StridePairA, + ElementPairB, + StridePairB, + TiledMma, + GmemTiledCopyPairA, + SmemLayoutAtomsA, + void, + cute::identity, + GmemTiledCopyPairB, + SmemLayoutAtomsB, + void, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index a6444e02..f63842c0 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -141,6 +141,18 @@ constexpr uint32_t find_vector_size() { cute::is_same_v || cute::is_same_v || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v ) { return 16; } diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index c734671b..ad2667dd 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -47,6 +47,9 @@ #include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_simt_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl" #include "cutlass/gemm/collective/builders/sm120_sparse_mma_builder.inl" diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 0bb6f722..f698b79c 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -63,6 +63,10 @@ #include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp" #include "cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp" +#include "cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp" #include "cutlass/gemm/collective/sm120_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_mma_tma.hpp" #include "cutlass/gemm/collective/sm120_blockscaled_mma_array_tma.hpp" diff --git a/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp new file mode 100644 index 00000000..e744ffb6 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/arch/memory.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); + static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); + + using DispatchPolicy = MainloopSm100UmmaCpAsyncWarpSpecialized< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + // TileShape refers to MmaTileShape to adapt for runtime cluster shape + using TileShape = TileShape_; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + // Define A and B block shapes + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); + using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + + // CtaShape_MNK is queried from collective in all kernel layers + using CtaShape_MNK = TileShape; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = StrideA_; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = StrideB_; + + static constexpr bool IsRuntimeDataTypeA = cute::is_same_v; + static constexpr bool IsRuntimeDataTypeB = cute::is_same_v; + + static_assert(IsRuntimeDataTypeA == IsRuntimeDataTypeB, + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineUmmaConsumerAsync; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + static_assert(size(GmemTiledCopyA{}) == size(GmemTiledCopyB{}), "A and B GmemTiledCopy should share the same thread count"); + static constexpr int NumLoadThreads = size(GmemTiledCopyA{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using MmaSmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + append(LoadShapeA_MK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using MmaSmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + using LoadSmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + append(LoadShapeB_NK{}, Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + return { + args.ptr_A, + args.dA, + args.ptr_B, + args.dB, + args.runtime_data_type_a, + args.runtime_data_type_b + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for CpAsync.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.ptr_A), make_shape(M,K,L), params.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), make_shape(N,K,L), params.dB); //(n,k,l) + // Partition for cpasync + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Build the coordinate tensors with the same shape as input matrices + Tensor cA_mk = make_identity_tensor(make_shape(M,K)); + Tensor cB_nk = make_identity_tensor(make_shape(N,K)); + + // Slice the coordinate tensors in the same way as A/B tensor partitioning + Tensor cgA_mk = local_tile(cA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), LoadSmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); + + GmemTiledCopyA gmem_to_smem_a_tiled_copy; + GmemTiledCopyB gmem_to_smem_b_tiled_copy; + + int thread_idx = threadIdx.x % NumLoadThreads; + auto thr_copy_a = gmem_to_smem_a_tiled_copy.get_slice(thread_idx); + auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // gmem + cgA_mk, cgB_nk, // crd + sA, sB, // smem + problem_shape_MNKL, + gmem_to_smem_a_tiled_copy, gmem_to_smem_b_tiled_copy, + thr_copy_a, thr_copy_b); + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + Params const& params, + [[maybe_unused]] cute::tuple, cute::Tensor> const& accumulators_pair, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), MmaSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), MmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + return cute::make_tuple(tiled_mma, tCrA, tCrB); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class CTensorA, class CTensorB, + class STensorA, class STensorB, + class ProblemShape_MNKL, + class TiledCopyA, class TiledCopyB, + class ThreadCopyA, class ThreadCopyB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + Params const& params, + MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + // Unpack from load_inputs + GTensorA tAgA_mkl = get<0>(load_inputs); + GTensorB tBgB_nkl = get<1>(load_inputs); + CTensorA cgA_mk = get<2>(load_inputs); + CTensorB cgB_nk = get<3>(load_inputs); + STensorA sA = get<4>(load_inputs); + STensorB sB = get<5>(load_inputs); + ProblemShape_MNKL problem_shape_MNKL = get<6>(load_inputs); + TiledCopyA gmem_to_smem_a_tiled_copy = get<7>(load_inputs); + TiledCopyB gmem_to_smem_b_tiled_copy = get<8>(load_inputs); + ThreadCopyA thr_copy_a = get<9>(load_inputs); + ThreadCopyB thr_copy_b = get<10>(load_inputs); + auto [M,N,K,L] = problem_shape_MNKL; + + // Slice out the work coord from partitioned tensors + Tensor gA_in = tAgA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Repeat slicing out coordinate tensor exactly the same as input tensor does + Tensor cgA_mk_in = cgA_mk(_, _, get<0>(cta_coord_mnkl), _); + Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + + auto k_residue = K - size<1>(gB_in) * size<2>(gA_in); + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, k_residue, 0), gA_in); + Tensor gB = domain_offset(make_coord(0, k_residue, 0), gB_in); + + Tensor cA = domain_offset(make_coord(0, k_residue, 0), cgA_mk_in); + Tensor cB = domain_offset(make_coord(0, k_residue, 0), cgB_nk_in); + + auto tAgA = thr_copy_a.partition_S(gA); + auto tAsA = thr_copy_a.partition_D(sA); + + auto tBgB = thr_copy_b.partition_S(gB); + auto tBsB = thr_copy_b.partition_D(sB); + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + Tensor tAcA = thr_copy_a.partition_S(cA); + Tensor tBcB = thr_copy_b.partition_S(cB); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + + // Repeating on predicators with the same operations on tAgA and tBgB + Tensor tAcAk = tAcA(_,_,_,*k_tile_iter); + Tensor tBcBk = tBcB(_,_,_,*k_tile_iter); + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = elem_less(get<0>(tAcAk(0,m,0)), M); // blk_m coord < M + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = elem_less(get<0>(tBcBk(0,n,0)), N); // blk_n coord < N + } + + // 0-th stage with predication on k to account for residue + // For performance consideration, + // this predicated block for K-tail is only activated when there is k-residue + if (k_residue != 0 && k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if ( int(get<1>(tAcAk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_a_tiled_copy, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (int(get<1>(tBcBk(0,0,k))) >= 0) { // blk_k coord < K + copy_if(gmem_to_smem_b_tiled_copy, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK mainloop_pipe_producer_state + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + + // Advance mainloop_pipe_producer_state + ++mainloop_pipe_producer_state; + } + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + auto mainloop_pipe_producer_state_curr = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state_curr, barrier_token); + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state_curr.index(); + + copy_if(gmem_to_smem_a_tiled_copy, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + mainloop_pipeline.producer_commit(mainloop_pipe_producer_state_curr, cutlass::arch::cpasync_barrier_arrive); + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline mainloop_pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB + > + CUTLASS_DEVICE auto + mma(MainloopPipeline mainloop_pipeline, + MainloopPipelineState mainloop_pipe_consumer_state, + cute::tuple, cute::Tensor> const& accumulators_pair, + cute::tuple const& mma_inputs, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + int read_stage = mainloop_pipe_consumer_state.index(); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state); + --k_tile_count; + ++mainloop_pipe_consumer_state; + } + + return mainloop_pipe_consumer_state; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp index e76818d4..047d9b98 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -347,20 +347,20 @@ struct CollectiveMma< template< class KTileCount, - class GTensorPartitionedScaleA, class GTensorPartitionedScaleB, - class IdentTensorPartitionedScaleA, class IdentTensorPartitionedScaleB, + class GTensorScaleA, class GTensorScaleB, + class IdentTensorScaleA, class IdentTensorScaleB, class STensorScaleA, class STensorScaleB > struct LoadSFParams { // for scheduler KTileCount k_tiles; - GTensorPartitionedScaleA tSFAgSFA_mkl; - GTensorPartitionedScaleB tSFBgSFB_nkl; - IdentTensorPartitionedScaleA tSFAIdentSFA_mkl; - IdentTensorPartitionedScaleB tSFBIdentSFB_nkl; - STensorScaleA tSFAsSFA; - STensorScaleB tSFBsSFB; + GTensorScaleA gSFA_mkl; + GTensorScaleB gSFB_nkl; + IdentTensorScaleA identSFA_mkl; + IdentTensorScaleB identSFB_nkl; + STensorScaleA sSFA; + STensorScaleB sSFB; LayoutSFA layout_SFA; LayoutSFB layout_SFB; @@ -368,14 +368,14 @@ struct CollectiveMma< CUTLASS_DEVICE LoadSFParams ( KTileCount k_tiles_, - GTensorPartitionedScaleA tSFAgSFA_mkl_, GTensorPartitionedScaleB tSFBgSFB_nkl_, - IdentTensorPartitionedScaleA tSFAIdentSFA_mkl_, IdentTensorPartitionedScaleB tSFBIdentSFB_nkl_, - STensorScaleA tSFAsSFA_, STensorScaleB tSFBsSFB_, + GTensorScaleA gSFA_mkl_, GTensorScaleB gSFB_nkl_, + IdentTensorScaleA identSFA_mkl_, IdentTensorScaleB identSFB_nkl_, + STensorScaleA sSFA_, STensorScaleB sSFB_, LayoutSFA layout_SFA_, LayoutSFB layout_SFB_) : k_tiles(k_tiles_) - , tSFAgSFA_mkl(tSFAgSFA_mkl_), tSFBgSFB_nkl(tSFBgSFB_nkl_) - , tSFAIdentSFA_mkl(tSFAIdentSFA_mkl_), tSFBIdentSFB_nkl(tSFBIdentSFB_nkl_) - , tSFAsSFA(tSFAsSFA_), tSFBsSFB(tSFBsSFB_) + , gSFA_mkl(gSFA_mkl_), gSFB_nkl(gSFB_nkl_) + , identSFA_mkl(identSFA_mkl_), identSFB_nkl(identSFB_nkl_) + , sSFA(sSFA_), sSFB(sSFB_) , layout_SFA(layout_SFA_), layout_SFB(layout_SFB_) {} }; @@ -732,35 +732,16 @@ struct CollectiveMma< static_assert(rank(decltype(gSFA_mkl){}) == 5); static_assert(rank(decltype(gSFB_nkl){}) == 5); - // 1 thread copies entire set of scalar - GmemTiledCopySFA scale_copy_a{}; - GmemTiledCopySFB scale_copy_b{}; - - ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); - ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); - Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) - Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) - Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) - - Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); - - Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) - Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) - Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); - - static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); - static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); - LoadSFParams load_params { size<3>(gSFA_mkl), - tSFAgSFA_mkl, tSFBgSFB_nkl, // for input scale tensor values - tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, // for predicating scale tensor copies - tSFAsSFA, tSFBsSFB, // for scale tensor values + gSFA_mkl, gSFB_nkl, // for input scale tensor values + identSFA_mkl, identSFB_nkl, // for predicating scale tensor copies + sSFA, sSFB, // for scale tensor values mainloop_params.layout_SFA, // for predicating scale tensor copies mainloop_params.layout_SFB // for predicating scale tensor copies }; @@ -922,24 +903,44 @@ struct CollectiveMma< KTileIterator k_tile_iter, int k_tile_count) { auto [unused_k_tiles, - tSFAgSFA_mkl, tSFBgSFB_nkl, - tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, - tSFAsSFA, tSFBsSFB, + gSFA_mkl, gSFB_nkl, + identSFA_mkl, identSFB_nkl, + sSFA, sSFB, layout_SFA, layout_SFB] = load_inputs; // slice out the work coord from partitioned tensors GmemTiledCopySFA scale_copy_a{}; GmemTiledCopySFB scale_copy_b{}; - Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gSFA_k_compact = filter_zeros( + gSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor gSFB_k_compact = filter_zeros( + gSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl))); // (BLK_N_CPT, BLK_K_CPT, k_cpt) - Tensor tSFBgSFB = tSFBgSFB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor identSFA_k_compact = filter_zeros( + identSFA_mkl(_, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFA_k_compact.stride()); // (BLK_M_CPT, BLK_K_CPT, k_cpt) + Tensor identSFB_k_compact = filter_zeros( + identSFB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)), + gSFB_k_compact.stride()); // (BLK_N_CPT, BLK_K_CPT, k_cpt) - Tensor thr_tile_SFA_k = tSFAIdentSFA_mkl(_0{}, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); - Tensor thr_tile_pSFA = make_tensor(shape(filter_zeros(thr_tile_SFA_k(_,_,_0{}), tSFAgSFA(_0{},_,_,_0{}).stride()))); - Tensor thr_tile_SFB_k = tSFBIdentSFB_nkl(_0{}, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor sSFA_compact = filter_zeros(sSFA); // (BLK_M_CPT, BLK_K_CPT, P) + Tensor sSFB_compact = filter_zeros(sSFB); // (BLK_N_CPT, BLK_K_CPT, P) - Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x % size(scale_copy_a)); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x % size(scale_copy_b)); + + Tensor tSFAgSFA_k_compact = thr_scale_copy_a.partition_S(gSFA_k_compact); // (CPY, BLK_M, BLK_K, k) + Tensor tSFAIdentSFA_k_compact = thr_scale_copy_a.partition_S(identSFA_k_compact); // (CPY, BLK_M, BLK_K, k) + + Tensor tSFAsSFA_compact = thr_scale_copy_a.partition_D(sSFA_compact); + + Tensor tSFBgSFB_k_compact = thr_scale_copy_b.partition_S(gSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBIdentSFB_k_compact = thr_scale_copy_b.partition_S(identSFB_k_compact); // (CPY, BLK_N, BLK_K, k) + Tensor tSFBsSFB_compact = thr_scale_copy_b.partition_D(sSFB_compact); + + Tensor thr_tile_pSFA = make_fragment_like(tSFAgSFA_k_compact(_0{},_,_,_0{})); + Tensor thr_tile_pSFB = make_fragment_like(tSFBgSFB_k_compact(_0{},_,_,_0{})); // Issue the loads CUTLASS_PRAGMA_NO_UNROLL @@ -949,18 +950,22 @@ struct CollectiveMma< CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFA); ++i) { - Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); - thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); + Tensor tSFAIdentSFA_compact = tSFAIdentSFA_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFA(i) = elem_less(tSFAIdentSFA_compact(i), + shape(filter_zeros(layout_SFA))) && threadIdx.x % 32 < size(scale_copy_a); } CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(thr_tile_pSFB); ++i) { - Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); - thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); + Tensor tSFBIdentSFB_compact = tSFBIdentSFB_k_compact(_0{},_,_,*k_tile_iter); + thr_tile_pSFB(i) = elem_less(tSFBIdentSFB_compact(i), + shape(filter_zeros(layout_SFB))) && threadIdx.x % 32 < size(scale_copy_b); } - copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,mainloop_sf_pipe_producer_state.index()))); - copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,mainloop_sf_pipe_producer_state.index()))); + copy_if(scale_copy_a, thr_tile_pSFA, tSFAgSFA_k_compact(_,_,_,*k_tile_iter), + tSFAsSFA_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); + copy_if(scale_copy_b, thr_tile_pSFB, tSFBgSFB_k_compact(_,_,_,*k_tile_iter), + tSFBsSFB_compact(_,_,_,mainloop_sf_pipe_producer_state.index())); mainloop_sf_pipeline.producer_commit(mainloop_sf_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); __syncwarp(); diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp new file mode 100644 index 00000000..5a5bf458 --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -0,0 +1,1296 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/numeric_conversion.h" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/detail/collective/mixed_input_utils.hpp" +#include "cutlass/detail/sm100_mixed_dtype_blockwise_layout.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/mma_sm100.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop for Mixed Input Kernels +template < + int Load2TransformPipelineStageCount_, + int Transform2MmaPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape, + class TileShape_, + class ElementAOptionalTuple_, + class StridePairA_, + class ElementBOptionalTuple_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomsA_, + class CopyAtomsA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomsB_, + class CopyAtomsB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>, + TileShape_, + ElementAOptionalTuple_, + StridePairA_, + ElementBOptionalTuple_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomsA_, + CopyAtomsA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomsB_, + CopyAtomsB_, + TransformB_> +{ +public: + // + // Type Aliases + // + + using ConversionMode = cutlass::detail::ConversionMode; + // Determine MMA type: MMA_1SM vs MMA_2SM + using AtomThrShapeMNK = Shape(typename TiledMma_::ThrLayoutVMNK{})), _1, _1>; + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedMixedInput< + Load2TransformPipelineStageCount_, + Transform2MmaPipelineStageCount_, + SchedulerPipelineStageCount_, + AccumulatorPipelineStageCount_, + ClusterShape>; + using TileShape = TileShape_; + using TiledMma = TiledMma_; + using KernelSchedule = typename DispatchPolicy::Schedule; + static constexpr bool IsDynamicCluster = not cute::is_static_v; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + using ElementAOptionalTuple = ElementAOptionalTuple_; + using ElementBOptionalTuple = ElementBOptionalTuple_; + +private: + + template friend struct detail::MixedInputUtils; + using CollectiveType = CollectiveMma; + using Utils = detail::MixedInputUtils; + + using ElementScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple_>; + using ElementScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ElementZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ElementZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + +public: + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutScale = cute::remove_cvref_t(StridePairA_{}))>; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || + (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert(( IsATransformed && (sizeof(ElementB) == 2)) || + (!IsATransformed && (sizeof(ElementA) == 2)) || + (cutlass::gemm::detail::is_k_major() && + cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + // Define A and B block shapes for reduced size TMA_LOADs + using CtaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using CtaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = GmemTiledCopyA_; + + using SmemLayoutAtomsA = SmemLayoutAtomsA_; + using SmemLayoutAtomsB = SmemLayoutAtomsB_; + using CopyAtomsA = CopyAtomsA_; + using CopyAtomsB = CopyAtomsB_; + using SmemCopyAtomScale = Copy_Atom; + + using SmemLayoutAtomA = typename SmemLayoutAtomsA::InputLayoutAtom; + using SmemLayoutAtomACompute = typename SmemLayoutAtomsA::ComputeLayoutAtom; + using SmemLayoutAtomB = typename SmemLayoutAtomsB::InputLayoutAtom; + using SmemLayoutAtomBCompute = typename SmemLayoutAtomsB::ComputeLayoutAtom; + + using InputCopyAtomA = typename CopyAtomsA::InputCopyAtom; + using ComputeCopyAtomA = typename CopyAtomsA::ComputeCopyAtom; + using InputCopyAtomB = typename CopyAtomsB::InputCopyAtom; + using ComputeCopyAtomB = typename CopyAtomsB::ComputeCopyAtom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemLayoutAtomACompute = cute::conditional_t; + using InternalSmemLayoutAtomBCompute = cute::conditional_t; + + using InternalInputCopyAtomA = cute::conditional_t; + using InternalInputCopyAtomB = cute::conditional_t; + using InternalComputeCopyAtomA = cute::conditional_t; + using InternalComputeCopyAtomB = cute::conditional_t; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealSwappedElementA = cute::conditional_t; + using RealSwappedElementB = cute::conditional_t; + using SwappedElementA = cute::conditional_t; + using SwappedElementB = cute::conditional_t; + using SwappedStrideA = cute::conditional_t; + using SwappedStrideB = cute::conditional_t; + using InternalSwappedStrideA = cute::conditional_t; + using InternalSwappedStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + using TmaElementScale = uint_bit_t >; // in case we have array. translating to uint to satisfy tma descriptor's specialization + + using ArchTag = typename DispatchPolicy::ArchTag; + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, + "Compute type A should be cutlass::bfloat16_t or cutlass::half_t or cutlass::float_e4m3_t"); + + using Load2TransformPipeline = cutlass::PipelineTmaTransformAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + AtomThrShapeMNK>; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Load2MmaPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Load2TransformPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using Load2MmaPipelineState = typename Load2MmaPipeline::PipelineState; + + using Transform2MmaPipeline = cutlass::PipelineUmmaConsumerAsync< + DispatchPolicy::Transform2MmaPipelineStageCount, + AtomThrShapeMNK>; + using Transform2MmaPipelineState = typename Transform2MmaPipeline::PipelineState; + + using Mma2AccumPipeline = cutlass::PipelineUmmaAsync< + DispatchPolicy::Schedule::AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2AccumPipelineState = typename Mma2AccumPipeline::PipelineState; + + + static constexpr int ScaleGranularityMN = size<0,0>(LayoutScale{}); + static constexpr int ScaleGranularityK = size<1,0>(LayoutScale{}); + using ScaleConfig = cutlass::detail::Sm100MixedInputBlockwiseScaleConfig< + ScaleGranularityMN, + ScaleGranularityK>; + + using ScaleTileShape = cute::conditional_t(TileShape{}), size<2>(TileShape{}))), + decltype(make_shape(size<1>(TileShape{}), size<2>(TileShape{})))>; + + static constexpr int ScaleTileShape_MN = get<0>(ScaleTileShape{}); + + static constexpr int ScaleK = get<1>(ScaleTileShape{}) / ScaleGranularityK; + + using SmemLayoutAtomScale = decltype(ScaleConfig::smem_atom_layout_scale(ScaleTileShape{})); + + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Thread Counts + static constexpr uint32_t NumTransformationThreads = 128; + static constexpr uint32_t NumAccumThreads = 128; //Maintains compatibility with input_transform kernel + + // Get the Algorithm parameters + constexpr static int AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + constexpr static int StagesPerTile = size<2>(CtaShapeA_MK{}); + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeA_MK{}) * size<1>(CtaShapeA_MK{})) % size<0>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeA_MK{}) * size<2>(CtaShapeA_MK{})) % size<1>(SmemLayoutAtomACompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert(((size<0,0>(CtaShapeB_NK{}) * size<1>(CtaShapeB_NK{})) % size<0>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + static_assert(((size<0,1>(CtaShapeB_NK{}) * size<2>(CtaShapeB_NK{})) % size<1>(SmemLayoutAtomBCompute{})) == 0, "SmemLayoutAtomCompute must evenly divide tile shape."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutACompute = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomACompute{}, + append(CtaShapeA_MK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(CtaShapeB_NK{}, Int{}), + (cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}))); + + using SmemLayoutScale = decltype(make_layout( + append(shape(SmemLayoutAtomScale{}), Int{}), + append(stride(SmemLayoutAtomScale{}), size(filter_zeros(SmemLayoutAtomScale{}))) + )); + + static_assert(DispatchPolicy::Load2TransformPipelineStageCount >= 2 && DispatchPolicy::Load2TransformPipelineStageCount >= 2, + "Specialization requires Stages set to value 2 or more."); + static_assert((cute::is_base_of::value || + cute::is_base_of::value ) && + cute::is_base_of::value, + "MMA atom must A operand from SMEM or TMEM and B operand from SMEM for this mainloop."); + static_assert((cute::is_same_v || cute::is_same_v), + "GmemTiledCopyA - invalid TMA copy atom specified."); + +private: + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + +public: + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && + cutlass::detail::is_Array_v; + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct PipelineStorage { + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + alignas(16) Load2TransformPipelineStorage load2transform_pipeline; + using Load2MmaPipelineStorage = typename Load2MmaPipeline::SharedStorage; + alignas(16) Load2MmaPipelineStorage load2mma_pipeline; + using Transform2MmaPipelineStorage = typename Transform2MmaPipeline::SharedStorage; + alignas(16) Transform2MmaPipelineStorage transform2mma_pipeline; + using Mma2AccumPipelineStorage = typename Mma2AccumPipeline::SharedStorage; + alignas(16) Mma2AccumPipelineStorage mma2accum_pipeline; + }; + + struct SharedStorage { + static constexpr int scale_elements = Utils::elements_per_smem_scale(); + static constexpr int zero_elements = Utils::elements_per_smem_zero(); + struct TensorStorage : cute::aligned_struct<128, _0> { + + struct TensorStorageUntransformed { + alignas(512) cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + }; + + struct TensorStorageTransformedAinSmem { + // We require alignas(1024) here because the smem_ACompute may not be aligned to 1024 by default. + // We need 1024B alignment of smem_ACompute because we are using Swizzle<3,4,3> here. + // The Swizzle<3,4,3> aligns with 1024B. If we don't align the data, the compiler cannot deduce + // the base pointer of the data. + // This alignment allows us to perform the function swizzle(layout(i) * base_ptr). + alignas(1024) cute::ArrayEngine> smem_ACompute; + }; + + union TensorStorageTransformedAinTmem { + cute::ArrayEngine smem_ACompute; // No smem_ACompute + }; + + using TensorStorageTransformed = cute::conditional_t< + cute::is_base_of::value, + TensorStorageTransformedAinSmem, + TensorStorageTransformedAinTmem>; + + TensorStorageUntransformed input; + TensorStorageTransformed compute; + } tensors; + + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Different from other GEMM kernels, both CTAs should be aware of loads. Both CTAs will work on + // loaded input A and B matrices to convert the data type + static constexpr uint32_t TmaTransactionBytes_A = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + Utils::compute_tma_transaction_bytes_extra_transform(); + static constexpr uint32_t TmaTransactionBytes_B = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytes_A + TmaTransactionBytes_B; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + StrideA dA{}; + ElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementScale const* ptr_S{nullptr}; + LayoutScale layout_S{}; + ElementZero const* ptr_Z{nullptr}; + }; + + struct TMAScaleParams { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_Scale = decltype(make_tma_atom( + GmemTiledCopyScale{}, + make_tensor(static_cast(nullptr), LayoutScale{}), + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(ClusterLayout_VMNK{})) + ); + + TMA_Scale tma_load_scale; + TMA_Scale tma_load_zero; + + }; + + struct EmptyScaleParams {}; + + // Device side kernel params + struct Params : public cute::conditional_t { + + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + + uint32_t tma_transaction_bytes{TmaTransactionBytes}; + SwappedStrideA dA{}; + SwappedStrideB dB{}; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor tensor_a = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + // Cluster layout for TMA construction + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + uint32_t tma_transaction_bytes = TmaTransactionBytes; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return { + {}, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr (ModeHasScales) { + ElementScale const* ptr_S = args.ptr_S; + + Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), args.layout_S); + typename Params::TMA_Scale tma_load_scale = make_tma_atom( + GmemTiledCopyScale{}, + tensor_scale, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(cluster_layout_vmnk) + ); + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale) { + typename Params::TMAScaleParams scale_params{tma_load_scale, {}}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(detail::get_logical_ptr(args.ptr_Z), args.layout_S); + typename Params::TMA_Scale tma_load_zero = make_tma_atom( + GmemTiledCopyScale{}, + tensor_zero, + SmemLayoutScale{}(_,_,cute::Int<0>{}), + ScaleTileShape{}, + size<2>(cluster_layout_vmnk)); + + typename Params::TMAScaleParams scale_params{tma_load_scale, tma_load_zero}; + return { + scale_params, + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + tma_transaction_bytes, + args.dA, args.dB }; + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_S = cutlass::detail::get_input_alignment_bits(); + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cutlass::sizeof_bits::value; + bool check_aligned_A = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cutlass::sizeof_bits::value; + bool check_aligned_B = cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + bool check_aligned_S = true; + bool check_aligned_Z = true; + bool check_mode_args = true; + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + check_mode_args = check_mode_args && (args.ptr_S == nullptr); + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (ModeHasScales) { + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_S = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + check_mode_args = check_mode_args && (args.ptr_Z == nullptr); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits_S / cutlass::sizeof_bits::value; + check_aligned_Z = cutlass::detail::check_alignment(args.layout_S); + check_mode_args = check_mode_args && (args.ptr_Z != nullptr); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!check_mode_args) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Invalid arguments for the selected conversion mode.\n"); + } + if (!check_aligned_A) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor A meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_B) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor B meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_S) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor S (scale) meet the minimum alignment requirements for TMA.\n"); + } + if (!check_aligned_Z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Tensor Z (zeros) meet the minimum alignment requirements for TMA.\n"); + } + + return check_mode_args && check_aligned_A && check_aligned_B && check_aligned_S && check_aligned_Z; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + } + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_zero.get_tma_descriptor()); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_A( + Params const& params, + Load2TransformPipeline load2xform_pipeline, + Load2TransformPipelineState load2xform_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // LOCK mainloop_load2xform_pipeline_state for _writing_ + load2xform_pipeline.producer_acquire(load2xform_pipeline_state, load2xform_pipeline_flag); + + int tile_A_write_stage = load2xform_pipeline_state.index(); + + BarrierType* load2xform_tma_barrier = load2xform_pipeline.producer_get_barrier(load2xform_pipeline_state); + + // Advance mainloop load2transform pipeline + ++load2xform_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2xform_pipeline_flag = load2xform_pipeline.producer_try_acquire(load2xform_pipeline_state, skip_wait); + + // TMA load for A k_tile + copy(observed_tma_load_a_->with(*load2xform_tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,tile_A_write_stage)); + + if constexpr (ModeHasScales) { + auto tSgS_mkl = get<0>(extra_input_partitions); + auto tSgS = tSgS_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tSsS = get<1>(extra_input_partitions); + copy(params.tma_load_scale.with(*load2xform_tma_barrier, mcast_mask_a), tSgS(_,*k_tile_iter), tSsS(_,tile_A_write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ_mkl = get<2>(extra_input_partitions); + auto tZgZ = tZgZ_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + auto tZsZ = get<3>(extra_input_partitions); + copy(params.tma_load_zero.with(*load2xform_tma_barrier, mcast_mask_a), tZgZ(_,*k_tile_iter), tZsZ(_,tile_A_write_stage)); + } + } + else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert); + else static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + } + + + return cute::make_tuple(load2xform_pipeline_state, k_tile_iter); + + } + + /// Produce the inputs to the transform threads by loading inputs from gmem -> smem + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator, + class... Ts + > + CUTLASS_DEVICE auto + load_B( + Params const& params, + Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_gA, unused_gB, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + mcast_mask_a, mcast_mask_b, extra_input_partitions] = load_inputs; + + // slice out the work coord from tiled tensors + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + //Load2Mma and Load2Transform pipelines both have the same ProducerBarrierType + using BarrierType = typename Load2TransformPipeline::ProducerBarrierType; + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // LOCK mainloop_load2mma_pipeline_state for _writing_ + load2mma_pipeline.producer_acquire(load2mma_pipeline_state, load2mma_pipeline_flag); + + int tile_B_write_stage = load2mma_pipeline_state.index(); + + BarrierType* load2mma_tma_barrier = load2mma_pipeline.producer_get_barrier(load2mma_pipeline_state); + + // Advance mainloop load2mma pipeline + ++load2mma_pipeline_state; + + skip_wait = (k_tile_count <= 1); + load2mma_pipeline_flag = load2mma_pipeline.producer_try_acquire(load2mma_pipeline_state, skip_wait); + + // TMA load for B k_tile + copy(observed_tma_load_b_->with(*load2mma_tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,tile_B_write_stage)); + + ++k_tile_iter; + } + + return cute::make_tuple(load2mma_pipeline_state, k_tile_iter); + + } + + /// Set up the data needed by this collective for load. + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tiled tensor for input A + /// gB_nkl - The tiled tensor for input B + // Other inputs needed for load(): partitioned AB tensors for gmem and smem, and mcast masks + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_storage) const { + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple()); + } + else if constexpr (ModeHasScales) { + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor mS_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gS_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + + Tensor tCgS_mkl = cta_mma.partition_A(gS_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCsS = cta_mma.partition_A(sS); + + // Project the cta_layout for tma_scale along the n-modes + auto [tSgS_mkl, tSsS] = tma_partition(params.tma_load_scale, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsS), group_modes<0,3>(tCgS_mkl)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS)); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = params.tma_load_scale.get_tma_tensor(shape(LayoutScale{})); + Tensor gZ_mkl = local_tile(mS_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + Tensor tCgZ_mkl = cta_mma.partition_A(gZ_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + + Tensor tCsZ = cta_mma.partition_A(sZ); + // Project the cta_layout for tma_scale along the n-modes + auto [tZgZ_mkl, tZsZ] = tma_partition(params.tma_load_zero, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsZ), group_modes<0,3>(tCgZ_mkl)); + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + cute::make_tuple(tSgS_mkl, tSsS, tZgZ_mkl, tZsZ)); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + + } + + template< + class KTileIterator, class Accumulator, + class GTensorA, class DstCopyA, class SrcTensorA, class DstTensorA, + class... Ts + > + CUTLASS_DEVICE auto + transform( + Load2TransformPipeline load2transform_pipeline, + Load2TransformPipelineState load2transform_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_producer_state, + Accumulator accumulators, + cute::tuple> input_operands, + KTileIterator k_tile_iter, int k_tile_count) { + + static_assert(cute::is_same_v, "ElementAMma and ElementBMma types should be the same."); + cutlass::arch::NamedBarrier transform_bar(NumTransformationThreads, cutlass::arch::ReservedNamedBarriers::TransformBarrier); + + // tAsA : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM) + // tAsACompute : (Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest, SmemStages (In SMEM or TMEM) + auto [unused_tAgA, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info] = input_operands; + + // Create the tensors in registers + auto tArA = make_tensor(tAsA(_,_,_,_,0).shape()); //(Copy,#Copy),MMA_Rest,MMA_M_Rest,MMA_K_Rest (Register) + auto tArACompute = make_tensor(tAsA(_,_,_,_,0).shape()); + constexpr int K_BLOCK_MAX = size<3>(tArA); + + uint32_t skip_wait = (k_tile_count <= 0); + auto load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + auto transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2transform_pipeline.consumer_wait(load2transform_pipeline_consumer_state, load2transform_flag); + + transform2mma_pipeline.producer_acquire(transform2mma_pipeline_producer_state, transform2mma_flag); + + int load2transform_consumer_index = load2transform_pipeline_consumer_state.index(); // read stage + int transform2mma_producer_index = transform2mma_pipeline_producer_state.index(); //write stage + + auto curr_load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state; + + // Copy the input A matrix from SMEM + copy(AutoVectorizingCopy{}, tAsA(_,_,_,_,load2transform_consumer_index), tArA); + // Copy scale/zero vector from SMEM + Utils::copy_scale_zeros_for_transform(partitioned_extra_info, load2transform_consumer_index); + + // Loads from SMEM are done. Signal the mainloop load as early as possible + transform_bar.sync(); + load2transform_pipeline.consumer_release(curr_load2transform_pipeline_consumer_state); + + auto curr_transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state; + + // Dequantize A with scale/zero in RF + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; k_block ++){ + Utils::dequantize_A_kblock_for_transform(tArA, tArACompute, partitioned_extra_info, k_block); + } + + // Dequantized A is stored into either Smem or Tmem + copy(dst_copy_A, tArACompute, tAsACompute(_,_,_,_,transform2mma_producer_index)); + + // fence for SMEM writes + cutlass::arch::fence_view_async_shared(); + if constexpr (is_tmem::value) { + // fence for TMEM writes if A operand is coming from TMEM + cutlass::arch::fence_view_async_tmem_store(); + } + + // Let the MMA know we are done transforming + transform2mma_pipeline.producer_commit(curr_transform2mma_pipeline_producer_state); + // Next pipeline stage + ++load2transform_pipeline_consumer_state; + ++transform2mma_pipeline_producer_state; + + skip_wait = (k_tile_count <= 1); + // Peek the next pipeline stage's barriers + load2transform_flag = load2transform_pipeline.consumer_try_wait(load2transform_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.producer_try_acquire(transform2mma_pipeline_producer_state, skip_wait); + } + return cute::make_tuple(load2transform_pipeline_consumer_state, transform2mma_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + transform_init( + Params const& params, + ProblemShape_MNKL const& problem_shape_MNKL, + Accumulator accumulators, + TensorStorage& shared_storage) { + + auto [gA_mkl, gB_nkl] = tile_input_tensors(params, problem_shape_MNKL); + + Tensor sA_orig = make_tensor(make_smem_ptr(shared_storage.input.smem_A.begin()), SmemLayoutA{}); + Tensor sA = as_position_independent_swizzle_tensor(sA_orig); + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + + Tensor sS = make_tensor(make_smem_ptr(shared_storage.input.smem_scale.begin()), SmemLayoutScale{}); + Tensor sZ = make_tensor(make_smem_ptr(shared_storage.input.smem_zero.begin()), SmemLayoutScale{}); + + // Map input, compute, and fragment tensors to + // Copy strategies and partitioned tensors. These will become the input + // operands of the transform function. Depending on MMA atom type, the + // operands can reside in SMEM or TMEM + auto setup_copy_ops = [&] ( + auto tensor_input, + auto input_copy_atom, + auto tensor_compute, + auto make_fragment, + auto compute_copy_atom) constexpr { + auto fragment_compute = make_fragment(tensor_compute); + if constexpr (cute::is_tmem>::value) { + // For M=128 with 2CTA MMA atoms, the TMEM tensor for A has a duplicated allocation. + // Instead of allocation a 64x16 TMEM tensor, we have a 128x16 allocation + // See: TmemAllocMode::Duplicated. + Tensor tensor_input2x = [&] () constexpr { + if constexpr (decltype(size<0,0>(fragment_compute) == Int<128>{} && size<0,0>(tensor_input) == Int<64>{})::value) { + return make_tensor(tensor_input.data(), + logical_product(tensor_input.layout(), + make_tile(make_tile(Layout<_2,_0>{},_),_,_,_))); + } + else { + return tensor_input; + } + }(); + + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + // If operand comes from TMEM, create the TMEM_STORE based copy + auto r2t_tiled_copy = make_tmem_copy(compute_copy_atom, fragment_compute(_,_,_,0)); + auto thr_r2t_tiled_copy = r2t_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2t_tiled_copy.partition_S(tensor_input2x); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + auto partitioned_tensor_compute = thr_r2t_tiled_copy.partition_D(fragment_compute); //(TMEM_STORE, TMEM_STORE_M, TMEM_STORE_N) + + // Source copy is based on the source operand of TMEM_STORE copy. + auto smem2reg_tiled_copy = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + return cute::make_tuple(smem2reg_tiled_copy, r2t_tiled_copy, partitioned_tensor_input, partitioned_tensor_compute); + } + else { + auto tensor_compute_ind_sw = as_position_independent_swizzle_tensor(tensor_compute); + auto r2s_tiled_copy = make_cotiled_copy(compute_copy_atom, Layout, Stride< _8,_1>>{}, + tensor_compute(_,_,_,0).layout()); + + auto smem2reg_tiled_copy = make_tiled_copy_S(input_copy_atom, r2s_tiled_copy); + auto thr_r2s_tiled_copy = r2s_tiled_copy.get_slice(threadIdx.x % NumTransformationThreads); + auto partitioned_tensor_input = thr_r2s_tiled_copy.partition_S(tensor_input); //(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + auto partitioned_tensor_compute = thr_r2s_tiled_copy.partition_D(tensor_compute_ind_sw);//(SMEM_STORE, SMEM_STORE_M, SMEM_STORE_N) + + + return cute::make_tuple(smem2reg_tiled_copy, AutoVectorizingCopy{}, partitioned_tensor_input, partitioned_tensor_compute); + } + }; + + auto [src_copy_A, dst_copy_A, tAsA, tAsACompute] = + setup_copy_ops(sA, InputCopyAtomA{}, sACompute, [&](auto &arg) {return TiledMma::make_fragment_A(arg);}, ComputeCopyAtomA{}); + + // Partition of thread -> shared and thread -> RF + auto fragment_compute = TiledMma::make_fragment_A(sACompute); + fragment_compute.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto r2t_tiled_copy = make_tmem_copy(ComputeCopyAtomA{}, fragment_compute(_,_,_,0)); + auto src_copy_scale = make_tiled_copy_S(Copy_Atom{}, r2t_tiled_copy); + + auto partitioned_extra_info = Utils::partition_extra_transform_info(TiledMma{}, src_copy_scale, shared_storage); + + return cute::make_tuple(gA_mkl, dst_copy_A, tAsA, tAsACompute, + partitioned_extra_info); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgEngine, class FrgLayout, + class TensorA, class TensorB + > + CUTLASS_DEVICE auto + mma( + Load2MmaPipeline load2mma_pipeline, + Load2MmaPipelineState load2mma_pipeline_consumer_state, + Transform2MmaPipeline transform2mma_pipeline, + Transform2MmaPipelineState transform2mma_pipeline_consumer_state, + Mma2AccumPipeline mma2accum_pipeline, + Mma2AccumPipelineState mma2accum_pipeline_producer_state, + cute::Tensor const& accumulators, + cute::tuple const& input_operands, + int k_tile_count + ) { + TiledMma tiled_mma; + + auto curr_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + auto next_load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state; + + auto curr_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + auto next_transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state; + + uint32_t skip_wait = (k_tile_count <= 0); + auto transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + auto load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + ++next_transform2mma_pipeline_consumer_state; + ++next_load2mma_pipeline_consumer_state; + + + // tCrA : (MMA), MMA_M, MMA_K, SmemStage (In SMEM or TMEM) + // We use SMEM stages to match #buffers in Load <-> Convert + // tCrB : (MMA), MMA_N, MMA_K, SmemStages (In SMEM) + auto const [tCrA, tCrB] = input_operands; + + mma2accum_pipeline.producer_acquire(mma2accum_pipeline_producer_state); + + int mma2accum_pipeline_producer_state_index = mma2accum_pipeline_producer_state.index(); + auto tCtC = accumulators(_,_,_,mma2accum_pipeline_producer_state_index); + auto curr_mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state; + ++mma2accum_pipeline_producer_state; + + // + // PIPELINED MAIN LOOP + // + // Clear the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + load2mma_pipeline.consumer_wait(curr_load2mma_pipeline_consumer_state, load2mma_flag); + transform2mma_pipeline.consumer_wait(curr_transform2mma_pipeline_consumer_state, transform2mma_flag); + + int load2mma_pipeline_consumer_state_index = curr_load2mma_pipeline_consumer_state.index(); //read_stage + int transform2mma_pipeline_consumer_state_index = curr_transform2mma_pipeline_consumer_state.index(); //read_stage + + auto tCrA0 = tCrA(_,_,_,transform2mma_pipeline_consumer_state_index); + auto tCrB0 = tCrB(_,_,_,load2mma_pipeline_consumer_state_index); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); k_block ++) { + cute::gemm(tiled_mma, tCrA0(_,_,k_block), tCrB0(_,_,k_block), tCtC); // A[0]*B[0] + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + + load2mma_pipeline.consumer_release(curr_load2mma_pipeline_consumer_state); + transform2mma_pipeline.consumer_release(curr_transform2mma_pipeline_consumer_state); + + skip_wait = (k_tile_count <= 1); + load2mma_flag = load2mma_pipeline.consumer_try_wait(next_load2mma_pipeline_consumer_state, skip_wait); + transform2mma_flag = transform2mma_pipeline.consumer_try_wait(next_transform2mma_pipeline_consumer_state, skip_wait); + + curr_load2mma_pipeline_consumer_state = next_load2mma_pipeline_consumer_state; + curr_transform2mma_pipeline_consumer_state = next_transform2mma_pipeline_consumer_state; + + ++next_load2mma_pipeline_consumer_state; + ++next_transform2mma_pipeline_consumer_state; + } + + mma2accum_pipeline.producer_commit(curr_mma2accum_pipeline_producer_state); + + return cute::make_tuple(curr_load2mma_pipeline_consumer_state, curr_transform2mma_pipeline_consumer_state, mma2accum_pipeline_producer_state); + } + + template + CUTLASS_DEVICE auto + mma_init(cute::Tensor const& accumulators, TensorStorage& shared_storage) const { + TiledMma tiled_mma; + + auto get_tCrA = [&] () constexpr { + if constexpr (cute::is_base_of::value) { + Tensor sACompute = make_tensor(make_smem_ptr(shared_storage.compute.smem_ACompute.begin()), SmemLayoutACompute{}); + return tiled_mma.make_fragment_A(sACompute); + } + else { + auto tCrA = tiled_mma.make_fragment_A(shape(SmemLayoutACompute{})); + tCrA.data() = accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + return tCrA; + } + }; + + Tensor tCrA = get_tCrA(); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.input.smem_B.begin()), SmemLayoutB{}); + Tensor tCrB = tiled_mma.make_fragment_B(sB); + return cute::make_tuple(tCrA, tCrB); + } + + template + CUTLASS_DEVICE auto + accum_init(cute::Tensor const& accumulators, TmemCopyAtom tmem_cp_atom, EpilogueTile epilogue_tile) { + return accumulators; + } + +private: + template + CUTLASS_DEVICE + constexpr auto + tile_input_tensors(Params const& params, ProblemShape_MNKL const& problem_shape_MNKL) const { + using X = cute::Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + typename Params::TMA_A const* observed_tma_load_a_ = nullptr; + typename Params::TMA_B const* observed_tma_load_b_ = nullptr; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp new file mode 100644 index 00000000..e90d7278 --- /dev/null +++ b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp @@ -0,0 +1,1685 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + using InternalStrideB = cute::remove_pointer_t; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using InternalLayoutSFA = cute::remove_pointer_t; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + using InternalLayoutSFB = cute::remove_pointer_t; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM103 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + cute::TmaDescriptor smem_tensormap_SFA; + cute::TmaDescriptor smem_tensormap_SFB; + } tensormaps; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const** ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const** ptr_B{nullptr}; + StrideB dB{}; + ElementSF const** ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const** ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), InternalLayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), InternalLayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + cute::TmaDescriptor* tensormaps; + ArrayElementA const** ptr_A; + StrideA dA; + ArrayElementB const** ptr_B; + StrideB dB; + ElementSF const** ptr_SFA; + ElementSF const** ptr_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shapes, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_M = int32_t(size<0>(TileShape{})); + auto init_N = int32_t(size<1>(TileShape{})); + auto init_K = int32_t(size<2>(TileShape{})); + auto init_L = 1; + + // Tensor pointers will be fixed before the first access + ElementA const* ptr_A_first_batch = nullptr; + ElementB const* ptr_B_first_batch = nullptr; + + InternalStrideA stride_a; + InternalStrideB stride_b; + InternalLayoutSFA layout_SFA; + InternalLayoutSFB layout_SFB; + + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(init_M, init_N, init_K, 1)); + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. + auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + layout_SFA = args.layout_SFA; + layout_SFB = args.layout_SFB; + } + + // Batches/Groups are managed by using appropriate pointers to input matrices. + Tensor tensor_a = recast(make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,init_L), stride_a))); + Tensor tensor_b = recast(make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,init_L), stride_b))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + // Tensor pointers will be fixed before the first access + ElementSF const* ptr_SFA_first_batch = nullptr; + ElementSF const* ptr_SFB_first_batch = nullptr; + + Tensor tensor_sfa = make_tensor(ptr_SFA_first_batch, layout_SFA); + Tensor tensor_sfb = make_tensor(ptr_SFB_first_batch, layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + + typename Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + + #if 0 + print("tma_load_a:\n"); + print(tma_load_a); + print("tma_load_a.tma_desc:\n"); print(tma_load_a.tma_desc_); print("\n"); + + print("tma_load_b:\n"); + print(tma_load_b); + print("tma_load_b.tma_desc:\n"); print(tma_load_b.tma_desc_); print("\n"); + + print("layout_SFA: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa:\n"); + print(tma_load_sfa); + print("tma_load_sfa.tma_desc:\n"); print(tma_load_sfa.tma_desc_); print("\n"); + + print("layout_SFB: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb:\n"); + print(tma_load_sfb); + print("tma_load_sfb.tma_desc:\n"); print(tma_load_sfb.tma_desc_); print("\n"); + + print("layout_sfa: "); print(args.layout_SFA); print("\n"); + print("tma_load_sfa_fallback:\n"); + print(tma_load_sfa_fallback); + print("tma_load_sfa_fallback.tma_desc:\n"); print(tma_load_sfa_fallback.tma_desc_); print("\n"); + + print("layout_sfb: "); print(args.layout_SFB); print("\n"); + print("tma_load_sfb_fallback:\n"); + print(tma_load_sfb_fallback); + print("tma_load_sfb_fallback.tma_desc:\n"); print(tma_load_sfb_fallback.tma_desc_); print("\n"); + #endif + + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + reinterpret_cast(workspace), + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_SFA), + reinterpret_cast(args.ptr_SFB) + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 4; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + + bool implementable = true; + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + } + } + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE auto + slice_accumulator(cute::Tensor const& accumulators, int stage) { + return accumulators(_,_,_,stage); + } + + template + CUTLASS_DEVICE auto + get_mkl_shape_tensor ( + ProblemShape_MNKL const& problem_shape_MNKL) const { + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); + return gA_mkl; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,mock_L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,mock_L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = tensormaps_init_ab(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, int32_t const sm_idx, + int32_t init_group) const { + using X = Underscore; + + // Separate out problem shape for convenience + + InternalLayoutSFA layout_SFA{}; + InternalLayoutSFB layout_SFB{}; + if constexpr (IsGroupedGemmKernel) { + layout_SFA = params.layout_SFA[init_group]; + layout_SFB = params.layout_SFB[init_group]; + } + else { + layout_SFA = params.layout_SFA; + layout_SFB = params.layout_SFB; + } + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + auto input_tensormaps = tensormaps_init_sf(params, shared_tensormaps, sm_count, sm_idx); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_sfa, mcast_mask_sfb, // multicast masks + input_tensormaps); // for tma descriptor modification (per-CTA tensormap copy) + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, + TmaPrefetchFn&& tma_prefetch_fn) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TensorMapA, class TensorMapB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + auto input_tensormaps = get<8>(load_inputs); + + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps)); + tensormaps_fence_acquire(get<1>(input_tensormaps)); + } + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + auto tma_copy_traits_a = observed_tma_load_a_->with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a); + auto tma_copy_traits_b = observed_tma_load_b_->with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b); + + if (cute::elect_one_sync()) { + copy(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(tma_copy_traits_a, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(tma_copy_traits_b, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TensorMapSFA, class TensorMapSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple> const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + bool did_batch_change, int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + auto input_tensormaps_sf = get<6>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + // Check to see if tensormaps have been replaced in gmem + if (did_batch_change) { + tensormaps_fence_acquire(get<0>(input_tensormaps_sf)); + tensormaps_fence_acquire(get<1>(input_tensormaps_sf)); + } + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + auto tma_copy_traits_sfa = observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa); + auto tma_copy_traits_sfb = observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(get<0>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(get<1>(input_tensormaps_sf), *tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(tma_copy_traits_sfa, tAgSFA_compact_prefetch); + prefetch(tma_copy_traits_sfb, tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. + pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. + if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + CUTLASS_DEVICE auto + tensormaps_init_ab( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(observed_tma_load_a_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(observed_tma_load_b_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + ElementA const* ptr_A = nullptr; + Tensor tensor_a = recast(make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group])); + + ElementB const* ptr_B = nullptr; + Tensor tensor_b = recast(make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_a_, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_b_, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_ab( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_ab_tensormaps, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address_ab(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_ab(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_ab(shared_tensormaps, input_ab_tensormaps); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_ab ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_ab_tensormaps) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_ab_tensormaps), shared_tensormaps.smem_tensormap_B); + + } + + // SF tensormap ops + CUTLASS_DEVICE auto + tensormaps_init_sf( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t const sm_count, + int32_t const sm_idx) const { + cute::TmaDescriptor* gmem_tensormap = mainloop_params.tensormaps; + + cute::TmaDescriptor* tma_desc_sfa = &gmem_tensormap[sm_idx + 2 * sm_count]; + cute::TmaDescriptor* tma_desc_sfb = &gmem_tensormap[sm_idx + 3 * sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pSFA_tensormap = make_tensor(observed_tma_load_sfa_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFA), Int<1>{}, Int<1>{}); + Tensor pSFB_tensormap = make_tensor(observed_tma_load_sfb_->get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sSFB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_SFB), Int<1>{}, Int<1>{}); + + copy(recast(pSFA_tensormap), recast(sSFA_tensormap)); + copy(recast(pSFB_tensormap), recast(sSFB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_sfa, tma_desc_sfb); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + mainloop_params.ptr_SFA[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + mainloop_params.ptr_SFB[next_batch]); + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_SFA = {1,1,1,1,1}; + cute::array prob_stride_SFA = {0,0,0,0,0}; + cute::array prob_shape_SFB = {1,1,1,1,1}; + cute::array prob_stride_SFB = {0,0,0,0,0}; + + ElementSF const* ptr_SF = nullptr; + Tensor tensor_sfa = make_tensor(ptr_SF, mainloop_params.layout_SFA[next_group]); + + Tensor tensor_sfb = make_tensor(ptr_SF, mainloop_params.layout_SFB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfa_, tensor_sfa, + prob_shape_SFA, prob_stride_SFA); + cute::detail::fill_tma_gmem_shape_stride(*observed_tma_load_sfb_, tensor_sfb, + prob_shape_SFB, prob_stride_SFB); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_SFA) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_SFB) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFA, + prob_shape_SFA, + prob_stride_SFA); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_SFB, + prob_shape_SFB, + prob_stride_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_perform_update_sf( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps_sf, + ProblemShape problem_shape, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address_sf(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(next_batch), 1); + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties_sf(shared_tensormaps, + mainloop_params, next_batch, problem_shape_MNKL); + } + } + // Ensure warp is converged before issuing tensormap fence release + __syncwarp(); + // Entire warp must do this (ie its aligned) + tensormaps_cp_fence_release_sf(shared_tensormaps, input_tensormaps_sf); + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release_sf ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps_sf) { + if (cute::elect_one_sync()) { + cute::tma_desc_commit_group(); + cute::tma_desc_wait_group(); + } + // Entire warp must do this (i.e. it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFA); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps_sf), shared_tensormaps.smem_tensormap_SFB); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* input_tma_desc) { + cute::tma_descriptor_fence_acquire(input_tma_desc); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp new file mode 100644 index 00000000..fefd7327 --- /dev/null +++ b/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp @@ -0,0 +1,1276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/sm103_blockscaled_layout.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int LoadABPipelineStageCount, + int LoadSFPipelineStageCount, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, int) + cutlass::sm103::detail::KernelPrefetchType PrefetchType, + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementPairA_, + class StridePairA_, + class ElementPairB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyPairA_, + class SmemLayoutAtomPairA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyPairB_, + class SmemLayoutAtomPairB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>, + TileShape_, + ElementPairA_, + StridePairA_, + ElementPairB_, + StridePairB_, + TiledMma_, + GmemTiledCopyPairA_, + SmemLayoutAtomPairA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyPairB_, + SmemLayoutAtomPairB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm103TmaUmmaWarpSpecializedBlockScaled< + LoadABPipelineStageCount, + LoadSFPipelineStageCount, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape, + PrefetchType>; + + using TileShape = TileShape_; + // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. + using TiledMMA_SF = TiledMMA, + Layout>, + Tile>; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr int SFVecSize = TiledMma::SFVecSize; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // Assert that TiledMma and TileShape should be weakly compatible + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TiledMma and TileShape should be weakly compatible"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, + "Cta N should be one of 128/192/256"); + + using ClusterTileShape = decltype(make_shape(get<0>(TileShape{})*get<0>(ClusterShape{}),get<1>(TileShape{})*get<1>(ClusterShape{}),get<2>(TileShape{})*get<2>(ClusterShape{}))); + using Sm1xxBlkScaledConfig = cutlass::detail::Sm103BlockScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + static constexpr int IsCtaN192 = shape<1>(CtaShape_MNK{}) == 192; + static int constexpr CTA_N_SF = cutlass::round_up(size<1>(CtaShape_MNK{}), Blk_MN{}); + // Tile shape used for partitioning Scale Factor B. + // The M-dim does not affect the SFB, so just set it as the original TileShape; + using TileShape_SF = decltype(make_shape(get<0>(CtaShape_MNK{}), + Int{} * shape<2>(typename TiledMma::ThrLayoutVMNK()), + get<2>(TileShape{}))); + + static int constexpr SF_BUFFERS_PER_TILE_K = SFVecSize == 16 ? 4 : 2; + using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/SF_BUFFERS_PER_TILE_K>{})); + + using ElementPairA = ElementPairA_; + using ElementPairB = ElementPairB_; + using ElementAMma = typename TiledMma::ValTypeA; + using ElementBMma = typename TiledMma::ValTypeB; + using StridePairA = StridePairA_; + using StridePairB = StridePairB_; + using SmemLayoutAtomPairA = SmemLayoutAtomPairA_; + using SmemLayoutAtomPairB = SmemLayoutAtomPairB_; + static_assert(cute::is_same_v(ElementPairA{}))>, + remove_cvref_t(ElementPairB{}))>>, "SFA and SFB data types should be the same"); + + // A and B matrices + using ElementA = remove_cvref_t(ElementPairA{}))>; + using StrideA = remove_cvref_t(StridePairA{}))>; + + using ElementB = remove_cvref_t(ElementPairB{}))>; + using StrideB = remove_cvref_t(StridePairB{}))>; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + // SFA and SFB + using ElementSF = remove_cvref_t(ElementPairA{}))>; + using LayoutSFA = remove_cvref_t(StridePairA{}))>; + using LayoutSFB = remove_cvref_t(StridePairB{}))>; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyPairA = GmemTiledCopyPairA_; + using GmemTiledCopyPairB = GmemTiledCopyPairB_; + using GmemTiledCopyA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopySFA = remove_cvref_t(GmemTiledCopyPairA{}))>; + using GmemTiledCopyB = remove_cvref_t(GmemTiledCopyPairB{}))>; + using GmemTiledCopySFB = remove_cvref_t(GmemTiledCopyPairB{}))>; + + using SmemLayoutAtomA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomSFA = remove_cvref_t(SmemLayoutAtomPairA{}))>; + using SmemLayoutAtomB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + using SmemLayoutAtomSFB = remove_cvref_t(SmemLayoutAtomPairB{}))>; + + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopABPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadABPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopABPipelineState = typename MainloopABPipeline::PipelineState; + + using MainloopSFPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::LoadSFPipelineStageCount, + ClusterShape, + AtomThrShapeMNK>; + using MainloopSFPipelineState = typename MainloopSFPipeline::PipelineState; + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,NUM_PIPES) + using SmemLayoutA_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(make_shape(make_shape(shape<0>(CtaShape_MNK{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_M,16bytes),1,8,3) + + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int{} /*PIPE*/), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,NUM_PIPES) + using SmemLayoutB_tma = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(make_shape(make_shape(shape<1>(CtaShape_MNK{}) / size(typename TiledMma::AtomThrID{}), _16{}), _1{}, _8{}), Int<3>{} /*Per mainloop iteration */), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // ((CTA_MMA_N,16bytes),1,8,3) + + + // SmemLayoutAtomSFA and SmemLayoutAtomSFB are for whole CTA tiles. We add the number of pipeline stages here. + // The number of pipeline stages is the same as the number of pipeline stages from AB Load <-> MainLoop + using SmemLayoutSFA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutSFB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + + using TmaInternalElementA = uint8_t; + using TmaInternalElementB = uint8_t; + + using SmemAllocTypeA = uint8_t; + using SmemAllocTypeB = uint8_t; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = typename detail::sm10x_block_scale_runtime_input_t::Type; + using RuntimeDataTypeB = typename detail::sm10x_block_scale_runtime_input_t::Type; + + using SmemPrefetchType = uint8_t; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_SFA; + cute::ArrayEngine> smem_SFB; + } tensors; + + using PipelineABStorage = typename MainloopABPipeline::SharedStorage; + using PipelineSFStorage = typename MainloopSFPipeline::SharedStorage; + struct PipelineStorage { + PipelineABStorage pipeline_ab; + PipelineSFStorage pipeline_sf; + }; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr uint32_t SFTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t ABTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementSF const* ptr_SFA{nullptr}; + LayoutSFA layout_SFA{}; + ElementSF const* ptr_SFB{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMma::AtomThrID{}))); + using ClusterLayoutSfb_VMNK = + decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), + ClusterShape{})), make_tile(typename TiledMMA_SF::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom( + GmemTiledCopyA{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{})), + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(ClusterShape{})) + ); + + using TMA_B = decltype(make_tma_atom( + GmemTiledCopyB{}, + recast(make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{})), + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(ClusterShape{})/size(typename TiledMma::AtomThrID{})) + ); + + using TMA_SFA = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFA{}, + make_tensor(static_cast(nullptr), LayoutSFA{}), + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(ClusterShape{})) + ); + + using TMA_SFB = decltype(make_tma_atom( // using legacy sm90 make_tma_atom + GmemTiledCopySFB{}, + make_tensor(static_cast(nullptr), LayoutSFB{}), + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(ClusterShape{})/size(typename TiledMMA_SF::AtomThrID{})) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_SFA tma_load_sfa; + TMA_SFB tma_load_sfb; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + TMA_SFA tma_load_sfa_fallback; + TMA_SFB tma_load_sfb_fallback; + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + observed_tma_load_sfa_ = is_fallback_cluster ? ¶ms.tma_load_sfa_fallback : ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = is_fallback_cluster ? ¶ms.tma_load_sfb_fallback : ¶ms.tma_load_sfb; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + observed_tma_load_sfa_ = ¶ms.tma_load_sfa; + observed_tma_load_sfb_ = ¶ms.tma_load_sfb; + + } + } + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = recast(make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA))); + Tensor tensor_b = recast(make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB))); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + + Tensor tensor_sfa = make_tensor(args.ptr_SFA, args.layout_SFA); + Tensor tensor_sfb = make_tensor(args.ptr_SFB, args.layout_SFB); + + typename Params::TMA_A tma_load_a = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape) + ); + typename Params::TMA_B tma_load_b = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape)/size(typename TiledMma::AtomThrID{}) + ); + typename Params::TMA_A tma_load_a_fallback = make_tma_atom( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA_tma{}, + make_tile(size<1,0>(typename TiledMma::ALayout{}), _384{}), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_B tma_load_b_fallback = make_tma_atom( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB_tma{}, + make_tile(size<1,0>(typename TiledMma::BLayout{}), _384{}), + size<0>(cluster_shape_fallback)/size(typename TiledMma::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape) + ); + typename Params::TMA_SFB tma_load_sfb = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape)/size(typename TiledMMA_SF::AtomThrID{}) + ); + typename Params::TMA_SFA tma_load_sfa_fallback = make_tma_atom( + GmemTiledCopySFA{}, + tensor_sfa, + SmemLayoutSFA{}(_,_,_,cute::Int<0>{}), + make_shape(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<1>(cluster_shape_fallback) + ); + typename Params::TMA_SFB tma_load_sfb_fallback = make_tma_atom( + GmemTiledCopySFB{}, + tensor_sfb, + SmemLayoutSFB{}(_,_,_,cute::Int<0>{}), + make_shape(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})), + size<0>(cluster_shape_fallback)/size(typename TiledMMA_SF::AtomThrID{}) + ); + return { + tma_load_a, + tma_load_b, + tma_load_sfa, + tma_load_sfb, + tma_load_a_fallback, + tma_load_b_fallback, + tma_load_sfa_fallback, + tma_load_sfb_fallback, + args.layout_SFA, + args.layout_SFB, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if constexpr (IsRuntimeDataType && detail::is_sm10x_mxf4nvf4_input() && detail::is_sm10x_mxf4nvf4_input()) { + bool is_compatible = (SFVecSize == 16 || + (SFVecSize == 32 && is_same_v + && args.runtime_data_type_a == cute::UMMA::MXF4Format::E2M1 + && args.runtime_data_type_b == cute::UMMA::MXF4Format::E2M1)); + if (!is_compatible) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: 2x mode (VectorSize=32) only supports float_e2m1_t for a/b types and ue8m0_t for sf type.\n"); + } + implementable &= is_compatible; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + if constexpr (IsDynamicCluster) { + dim3 cs = cute::cluster_shape(); + const bool is_fallback_cluster = (cs.x == params.cluster_shape_fallback.x && cs.y == params.cluster_shape_fallback.y); + if (is_fallback_cluster) { + cute::prefetch_tma_descriptor(params.tma_load_a_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa_fallback.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb_fallback.get_tma_descriptor()); + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + else { + cute::prefetch_tma_descriptor(params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfa.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_sfb.get_tma_descriptor()); + } + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + template + CUTLASS_DEVICE auto + load_ab_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + int K_recast = (K*cute::sizeof_bits_v/8); + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K_recast,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K_recast,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, replace<2>(TileShape{}, _384{}), make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl_tmp = cta_mma.partition_A(gA_mkl); // ((CTA_MMA_M,96),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor cta_tCgA = make_tensor(tCgA_mkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgA_mkl_tmp), cute::layout<1>(tCgA_mkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgA_mkl_tmp), cute::layout<2>(tCgA_mkl_tmp))), + cute::layout<3>(tCgA_mkl_tmp), cute::layout<4>(tCgA_mkl_tmp), cute::layout<5>(tCgA_mkl_tmp))); // (CTA_M,CTA_K,m,k,l) + + Tensor tCgA_mkl = make_tensor(cta_tCgA.data(), tiled_divide(cta_tCgA.layout(), + make_tile(size<1,0>(typename TiledMma::ALayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + + Tensor tCgB_nkl_tmp = cta_mma.partition_B(gB_nkl); // ((MMA_ATOM_M,96),Rest_MMA_M,Rest_MMA_K, n, k, l) + Tensor cta_tCgB = make_tensor(tCgB_nkl_tmp.data(), make_layout(coalesce(make_layout(cute::layout<0,0>(tCgB_nkl_tmp), cute::layout<1>(tCgB_nkl_tmp))), + coalesce(make_layout(cute::layout<0,1>(tCgB_nkl_tmp), cute::layout<2>(tCgB_nkl_tmp))), + cute::layout<3>(tCgB_nkl_tmp), cute::layout<4>(tCgB_nkl_tmp), cute::layout<5>(tCgB_nkl_tmp))); // (CTA_M,CTA_K,m,k,l) + Tensor tCgB_nkl = make_tensor(cta_tCgB.data(), tiled_divide(cta_tCgB.layout(), + make_tile(size<1,0>(typename TiledMma::BLayout{}) /*MMA_M for SM100*/, + _128{} /*128bytes*/))); // ((CTA_MMA_M,256),Rest_MMA_M, Rest_MMA_K, m, k, l) + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_N,32),Rest_MMA_N,8,NUM_PIPE) + + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,1>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,1>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + return cute::make_tuple( + gA_mkl, gB_nkl, // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + mcast_mask_a, mcast_mask_b // multicast masks + ); + } + + + /// Set up the data needed by this collective for load. + /// Return tuple element contain + /// tAgA_mkl - partitioned gmem tensor for A + /// tBgB_nkl - partitioned gmem tensor for B + /// mcast_mask_sfa - tma multicast mask for SFA + /// mcast_mask_sfb - tma multicast mask for SFB + template + CUTLASS_DEVICE auto + load_sf_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensor of Scale factors + Tensor mSFA_mkl = observed_tma_load_sfa_->get_tma_tensor(shape(params.layout_SFA)); + auto mSFB_nkl = [=](){ + if constexpr (IsCtaN192) { + Tensor mSFB_tmp = observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + auto x = stride<0,1>(mSFB_tmp); + auto y = ceil_div(shape<0,1>(mSFB_tmp), 4); + auto new_shape = make_shape (make_shape( shape<0,0>(mSFB_tmp), + make_shape( make_shape(_2{}, _2{}), y)), shape<1>(mSFB_tmp), shape<2>(mSFB_tmp)); + auto new_stride = make_stride(make_stride(stride<0,0>(mSFB_tmp), + make_stride(make_stride( x, x), x*3)), stride<1>(mSFB_tmp), stride<2>(mSFB_tmp)); + return make_tensor(mSFB_tmp.data(), make_layout(new_shape, new_stride)); + } + else { + return observed_tma_load_sfb_->get_tma_tensor(shape(params.layout_SFB)); + } + }(); + + // Partition for this CTA + Tensor gSFA_mkl = local_tile(mSFA_mkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (TILE_M,TILE_K,m,k,l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, MMA_SF_Tiler{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) + + Tensor tCgSFA_mkl = make_tensor(gSFA_mkl.data(), tiled_divide(gSFA_mkl.layout(), make_tile(get<0>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_M,MMA_K),Rest_MMA_M,Rest_MMA_K, m, k, l) + Tensor tCgSFB_nkl = make_tensor(gSFB_nkl.data(), tiled_divide(gSFB_nkl.layout(), make_tile(get<1>(MMA_SF_Tiler{}), get<2>(MMA_SF_Tiler{})))); // ((MMA_N,MMA_K),Rest_MMA_N,Rest_MMA_K, n, k, l) + + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + Layout cta_layout_mnk = make_layout(cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape())); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + int block_rank_in_cluster = cute::block_rank_in_cluster(); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(block_rank_in_cluster); + // Project the cta_layout for tma_a along the n-modes + auto [tAgSFA_mkl, tAsSFA] = tma_partition(*observed_tma_load_sfa_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsSFA), group_modes<0,3>(tCgSFA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgSFB_nkl, tBsSFB] = tma_partition(*observed_tma_load_sfb_, + get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), + group_modes<0,3>(tCsSFB), group_modes<0,3>(tCgSFB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + + return cute::make_tuple( + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_sfa, mcast_mask_sfb // multicast masks + ); + } + + /// Set up the data needed by this collective for mma compute. + CUTLASS_DEVICE auto + mma_init( + Params const& params, + TensorStorage& shared_tensors, + uint32_t const tmem_offset) const { + + // Allocate "fragments/descriptors" for A and B matrices + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // ((CTA_MMA_M,32),Rest_MMA_M,8,NUM_PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = make_tensor(sA);; + Tensor tCrB = make_tensor(sB);; + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE + + // + // Scale Factor + // + Tensor tCtSFA = make_tensor(take<0,3>(shape(SmemLayoutAtomSFA{}))); + // TMEM allocations for SFA and SFB will always start at DP 0. + tCtSFA.data() = tmem_offset; + Tensor tCtSFB = make_tensor(take<0,3>(shape(SmemLayoutAtomSFB{}))); + + tCtSFB.data() = tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tCtSFA); + + // Setup smem descriptors for UTCCP + Tensor tCsSFA = make_tensor(make_smem_ptr(shared_tensors.smem_SFA.begin()), SmemLayoutSFA{}); + Tensor tCsSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); + + // Make SMEM and TMEM tensors compact removing the zero strides to eliminate unnecessary copy instructions. + auto tCsSFA_compact = make_tensor(tCsSFA.data(), filter_zeros(tCsSFA.layout())); + auto tCtSFA_compact = make_tensor(tCtSFA.data(), filter_zeros(tCtSFA.layout())); + auto tCsSFB_compact = make_tensor(tCsSFB.data(), filter_zeros(tCsSFB.layout())); + auto tCtSFB_compact = make_tensor(tCtSFB.data(), filter_zeros(tCtSFB.layout())); + + // Create the SMEM to TMEM copy operations based on the MMA atom used (1CTA vs 2CTA) + using AtomThrID = typename TiledMma::AtomThrID; + using UtccpOp = cute::conditional_t<(decltype(cute::size(AtomThrID{}) == Int<2>{})::value), + SM100_UTCCP_4x32dp128bit_2cta, SM100_UTCCP_4x32dp128bit_1cta>; + auto tCtSFA_compact_copy = make_tensor(tCtSFA_compact.data(), append<3>(tCtSFA_compact(_,_0{},_0{}).layout())); + auto tCtSFB_compact_copy = make_tensor(tCtSFB_compact.data(), append<3>(tCtSFB_compact(_,_0{},_0{}).layout())); + auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact_copy); + auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact_copy); + + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); + auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); + + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); + // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor + auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); + auto thr_tCtSFB_compact_s2t = thr_copy_s2t_SFB.partition_D(tCtSFB_compact); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + tiled_mma.idesc_.a_format_ = uint8_t(params.runtime_data_type_a) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(params.runtime_data_type_b) & 0b111; + } + + // using MMA_SF_Tiler = decltype(make_tile(shape<0>(CtaShape_MNK{}), Int{}, Int(CtaShape_MNK{})/2>{})); // 128x128x384 + // MMA shapes are ((_128,_96),_1,_8) which makes the MMA_SFA_Shape ((128, (16,3)), 1, 8/3) + // The number is not divisible by 4 in K dimension which is needed for TMEM allocation. + // To be able to iterate thru the SFs for MMA, we model this as (MMA), MMA_M, MMA_K: ((128, (16,1)), 1, 24) + // with this layout we can iterate thru the SFs by incrementing MMA_K mode by 3/6 for this example (Vs=16 vs Vs=32). + constexpr int MMA_M = size<0>(CtaShape_MNK{}); + constexpr int MMA_N_SF = CTA_N_SF; + constexpr int MMA_K_SF = shape<2>(CtaShape_MNK{}) / 2; + auto mnBasicBlockShape = make_shape(_32{}, _4{}); + auto kBasicBlockShape_single = make_shape(Int{}, Int<1>{}); + auto mma_iter_SFA_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFA_iter_shape = make_shape(mma_iter_SFA_shape, _1{}, Int{}); + auto mma_iter_SFB_shape = make_shape( prepend(Int{}, mnBasicBlockShape), kBasicBlockShape_single); + auto sSFB_iter_shape = make_shape(mma_iter_SFB_shape, _1{}, Int{}); + + // Used for MMAs + using MmaIterShapeSFA = decltype(sSFA_iter_shape); // ((32,4),(SFVecSize,1), MMA_M/128, SF_MMA_K/SfVecSize + using MmaIterShapeSFB = decltype(sSFB_iter_shape); // ((32,4),(SFVecSize,1), MMA_N/128, SF_MMA_K/SfVecSize + + Tensor tCtSFA_mma = make_tensor(MmaIterShapeSFA{}); + tCtSFA_mma.data() = tCtSFA.data(); + Tensor tCtSFB_mma = make_tensor(MmaIterShapeSFB{}); + tCtSFB_mma.data() = tCtSFB.data(); + + return cute::make_tuple( + tiled_mma, + tCrA, tCrB, tCtSFA, tCtSFB, tCtSFA_mma, tCtSFB_mma, + tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t); + } + +// Helper function to handle both prefetch types + template + CUTLASS_DEVICE void issue_prefetch( + int& prefetch_k_tile_count, + int& prefetch_buf_idx, + KTileIterator& prefetch_k_tile, + TmaPrefetchFn&& tma_prefetch_fn + ) + { + if (prefetch_k_tile_count > 0) { + if constexpr (PrefetchType == cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch) { + tma_prefetch_fn(); + } + prefetch_buf_idx = (prefetch_buf_idx + 1) % BuffersPerKtile; + if(prefetch_buf_idx == 0) { + ++prefetch_k_tile; + --prefetch_k_tile_count; + } + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorA, class GTensorB, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_ab( + Params const& params, + MainloopABPipeline pipeline, + MainloopABPipelineState mainloop_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgA_mkl = get<2>(load_inputs); + auto tBgB_nkl = get<3>(load_inputs); + auto tAsA = get<4>(load_inputs); + auto tBsB = get<5>(load_inputs); + auto mcast_mask_a = get<6>(load_inputs); + auto mcast_mask_b = get<7>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, _, _, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + constexpr int BuffersPerKtile = 3; + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadABPipelineStageCount / BuffersPerKtile; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadABPipelineStageCount % BuffersPerKtile; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + using BarrierType = typename MainloopABPipeline::ProducerBarrierType; + // In total, we will load 3 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < BuffersPerKtile; buffer++) { + pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), group_modes<0,2>(tAgA(_,_,buffer,*k_tile_iter)), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), group_modes<0,2>(tBgB(_,_,buffer,*k_tile_iter)), tBsB(_,write_stage)); + } + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_a_, group_modes<0,2>(tAgA(_,_,prefetch_buf_idx,*prefetch_k_tile))); + prefetch(*observed_tma_load_b_, group_modes<0,2>(tBgB(_,_,prefetch_buf_idx,*prefetch_k_tile))); + } + ); + } + } + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, k_tile_iter); + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class GTensorPartitionedSFA, class GTensorPartitionedSFB, + class STensorSFA, class STensorSFB, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load_sf( + Params const& params, + MainloopSFPipeline pipeline, + MainloopSFPipelineState mainloop_sf_pipe_producer_state, + cute::tuple const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count, + int prefetch_k_tile_count = 0) { + + auto tAgSFA_mkl = get<0>(load_inputs); + auto tBgSFB_nkl = get<1>(load_inputs); + auto tAsSFA = get<2>(load_inputs); + auto tBsSFB = get<3>(load_inputs); + auto mcast_mask_sfa = get<4>(load_inputs); + auto mcast_mask_sfb = get<5>(load_inputs); + // slice out the work coord from partitioned tensors + Tensor tAgSFA = tAgSFA_mkl(_, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor tBgSFB = tBgSFB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + auto barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + + using BarrierType = typename MainloopSFPipeline::ProducerBarrierType; + auto tAsSFA_compact = make_tensor(tAsSFA.data(), filter_zeros(tAsSFA.layout())); + auto tBsSFB_compact = make_tensor(tBsSFB.data(), filter_zeros(tBsSFB.layout())); + auto prefetch_k_tile = k_tile_iter; + auto prefetch_buf_idx = 0; + auto tile_k_advance = LoadSFPipelineStageCount / SF_BUFFERS_PER_TILE_K; + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + prefetch_buf_idx = LoadSFPipelineStageCount % SF_BUFFERS_PER_TILE_K; + CUTLASS_PRAGMA_UNROLL + for (int i=0;i 0) { + // In total, we will load 2 or 4 buffers per k_tile_iter. Unrolled. + CUTLASS_PRAGMA_UNROLL + for(int buffer = 0; buffer < SF_BUFFERS_PER_TILE_K; buffer++) { + pipeline.producer_acquire(mainloop_sf_pipe_producer_state, barrier_token); + BarrierType* tma_barrier = pipeline.producer_get_barrier(mainloop_sf_pipe_producer_state); + + int write_stage = mainloop_sf_pipe_producer_state.index(); + ++mainloop_sf_pipe_producer_state; + barrier_token = pipeline.producer_try_acquire(mainloop_sf_pipe_producer_state); + auto tAgSFA_compact = make_tensor(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tAgSFA(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + auto tBgSFB_compact = make_tensor(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).data(), filter_zeros(tBgSFB(_,*k_tile_iter*SF_BUFFERS_PER_TILE_K + buffer).layout())); + + if (cute::elect_one_sync()) { + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA_compact, tAsSFA_compact(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB_compact, tBsSFB_compact(_,write_stage)); + } + #if 0 + if(threadIdx.x == 256 && blockIdx.x == 1 && blockIdx.y == 0) { + print("tAgSFA_compact: "); print(tAgSFA_compact); print("\n"); + print("tBgSFB_compact: "); print(tBgSFB_compact); print("\n"); + } + #endif + + auto tAgSFA_compact_prefetch = make_tensor(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tAgSFA(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + auto tBgSFB_compact_prefetch = make_tensor(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).data(), filter_zeros(tBgSFB(_,*prefetch_k_tile*SF_BUFFERS_PER_TILE_K + prefetch_buf_idx).layout())); + + if constexpr (PrefetchType != cutlass::sm103::detail::KernelPrefetchType::Disable) { + issue_prefetch ( + prefetch_k_tile_count, + prefetch_buf_idx, + prefetch_k_tile, + [&]() { + prefetch(*observed_tma_load_sfa_, tAgSFA_compact_prefetch); + prefetch(*observed_tma_load_sfb_, tBgSFB_compact_prefetch); + } + ); + } + } + + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_sf_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + template < + class MainloopPipeline, class MainloopPipelineState + > + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, MainloopPipelineState mainloop_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + pipeline.producer_tail(mainloop_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class AccumulatorPipeline, + class FrgEngine, class FrgLayout, + class FragmentA, class FragmentB, + class FragmentSFA, class FragmentSFB, + class MmaFragmentSFA, class MmaFragmentSFB, + class CtaTileCoord, + class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, + class SFBTiledCopy, class SmemFrgSFB, class TmemFrgSFB + > + CUTLASS_DEVICE auto + mma(cute::tuple pipelines, + cute::tuple pipeline_states, + cute::Tensor& accumulators, + cute::tuple const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count + ) { + static_assert(is_tmem::value, "Accumulator must be tmem resident."); + static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto pipeline_ab = get<0>(pipelines); + auto pipeline_sf = get<1>(pipelines); + auto accumulator_pipeline = get<2>(pipelines); + auto mainloop_pipe_ab_consumer_state = get<0>(pipeline_states); + auto mainloop_pipe_sf_consumer_state = get<1>(pipeline_states); + auto accumulator_pipe_producer_state = get<2>(pipeline_states); + auto tiled_mma = get<0>(mma_inputs); + auto tCrA = get<1>(mma_inputs); + auto tCrB = get<2>(mma_inputs); + auto tCtSFA = get<3>(mma_inputs); + auto tCtSFB = get<4>(mma_inputs); + auto tCtSFA_mma = get<5>(mma_inputs); + auto tCtSFB_mma = get<6>(mma_inputs); + auto tiled_copy_s2t_SFA = get<7>(mma_inputs); + auto tCsSFA_s2t = get<8>(mma_inputs); + auto tCtSFA_s2t = get<9>(mma_inputs); + auto tiled_copy_s2t_SFB = get<10>(mma_inputs); + auto tCsSFB_s2t = get<11>(mma_inputs); + auto tCtSFB_s2t = get<12>(mma_inputs); + + tCtSFB_mma = [tCtSFB_mma = tCtSFB_mma, cta_tile_coord]() { + if constexpr (IsCtaN192) { + // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) + auto tCtSFB_tmp = tCtSFB_mma; + if (get<1>(cta_tile_coord) % 2 == 1) { + tCtSFB_tmp.data() = tCtSFB_tmp.data().get() + 2; + } + return tCtSFB_tmp; + } + else { + return tCtSFB_mma; + } + }(); + + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + constexpr int sf_stride = TiledMma::SFVecSize == 16 ? 6 : 3; + auto barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + auto barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state); + constexpr int MmasPerSfBuffer = 8 / SF_BUFFERS_PER_TILE_K; + + auto sf_load_fn = [&](const int kphase, const int k_tile_count) { + if (kphase % MmasPerSfBuffer == 0) { + pipeline_sf.consumer_wait(mainloop_pipe_sf_consumer_state, barrier_token_sf); + int read_stage_sf_buffer0 = mainloop_pipe_sf_consumer_state.index(); + if (cute::elect_one_sync()) { + copy(tiled_copy_s2t_SFA, tCsSFA_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFA_s2t); + copy(tiled_copy_s2t_SFB, tCsSFB_s2t(_,_,_,_,read_stage_sf_buffer0), tCtSFB_s2t); + } + auto buffer0_mainloop_pipe_sf_consumer_state = mainloop_pipe_sf_consumer_state; + ++mainloop_pipe_sf_consumer_state; + barrier_token_sf = pipeline_sf.consumer_try_wait(mainloop_pipe_sf_consumer_state, (kphase == 8 - MmasPerSfBuffer) && k_tile_count <= 1); // only skip wait for the last one. + pipeline_sf.consumer_release(buffer0_mainloop_pipe_sf_consumer_state); + } + }; + + bool is_first_iteration = true; + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // MMA 0 + sf_load_fn(0, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer0 = mainloop_pipe_ab_consumer_state.index(); + auto buffer0_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + // delay the acc acquire to unblock tmem copy. + if constexpr (IsOverlappingAccum) { + if(is_first_iteration) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iteration = false; + } + }; + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,0,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,0,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 0 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + + // MMA 1 + sf_load_fn(1, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,3,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrA(_,_,0,read_stage_ab_buffer0), // Next A buffer for circular buffers: Points to buffer[0] + tCtSFA_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,3,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 48 bytes. Note the 3. + tCrB(_,_,0,read_stage_ab_buffer0), // Next B buffer for circular buffers: Points to buffer[0] + tCtSFB_mma(_, _, 1 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + + // MMA 2 + sf_load_fn(2, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer1 = mainloop_pipe_ab_consumer_state.index(); + auto buffer1_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,6,read_stage_ab_buffer0), // A buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,6,read_stage_ab_buffer0), // B buffer: Points to buffer[0] + 96 bytes. Note the 6. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 2 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer0_mainloop_pipe_ab_consumer_state); + + + // MMA 3 + sf_load_fn(3, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,1,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,1,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 16 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 3 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 4 + sf_load_fn(4, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,4,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrA(_,_,0,read_stage_ab_buffer1), // Next A buffer for circular buffers: Points to buffer[1]. + tCtSFA_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,4,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 64 bytes. Note the 1. + tCrB(_,_,0,read_stage_ab_buffer1), // Next B buffer for circular buffers: Points to buffer[1]. + tCtSFB_mma(_, _, 4 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + // MMA 5 + sf_load_fn(5, k_tile_count); + pipeline_ab.consumer_wait(mainloop_pipe_ab_consumer_state, barrier_token_ab); + int read_stage_ab_buffer2 = mainloop_pipe_ab_consumer_state.index(); + auto buffer2_mainloop_pipe_ab_consumer_state = mainloop_pipe_ab_consumer_state; + ++mainloop_pipe_ab_consumer_state; + barrier_token_ab = pipeline_ab.consumer_try_wait(mainloop_pipe_ab_consumer_state, k_tile_count <= 1); + + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,7,read_stage_ab_buffer1), // A buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,7,read_stage_ab_buffer1), // B buffer: Points to buffer[1] + 112 bytes. Note the 7. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 5 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer1_mainloop_pipe_ab_consumer_state); + + // MMA 6 + sf_load_fn(6, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,2,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,2,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 32 bytes. Note the 2. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 6 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + // MMA 7 + sf_load_fn(7, k_tile_count); + cute::gemm(tiled_mma, + make_zip_tensor(tCrA(_,_,5,read_stage_ab_buffer2), // A buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrA(_,_,0,read_stage_ab_buffer2), // Next A buffer for circular buffers: Points to buffer[2]. + tCtSFA_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFA + make_zip_tensor(tCrB(_,_,5,read_stage_ab_buffer2), // B buffer: Points to buffer[1] + 80 bytes. Note the 5. + tCrB(_,_,0,read_stage_ab_buffer2), // Next B buffer for circular buffers: Points to buffer[2]. + tCtSFB_mma(_, _, 7 % MmasPerSfBuffer * sf_stride)), // Tmem tensors for SFB + accumulators); // (V,M) x (V,N) => (V,M,N) + + pipeline_ab.consumer_release(buffer2_mainloop_pipe_ab_consumer_state); + --k_tile_count; + } + return cute::make_tuple(mainloop_pipe_ab_consumer_state, mainloop_pipe_sf_consumer_state); + } + +protected: + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + typename Params::TMA_SFA const* observed_tma_load_sfa_{nullptr}; + typename Params::TMA_SFB const* observed_tma_load_sfb_{nullptr}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 67c82688..7f956539 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -153,7 +153,15 @@ struct CollectiveMma< static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); - using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig; + static constexpr bool MMajorSFA = size<0,1>(InternalLayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(InternalLayoutSFB{}.stride()) == 1; + + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? cute::GMMA::Major::MN : cute::GMMA::Major::K>; using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index ecbd59b5..7def1f32 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -134,9 +134,12 @@ struct CollectiveMma< static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static constexpr bool MMajorSFA = size<0,1>(LayoutSFA{}.stride()) == 1; + static constexpr bool NMajorSFB = size<0,1>(LayoutSFB{}.stride()) == 1; + static constexpr int ScaleTmaThreshold = 32; - static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold; - static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold; + static constexpr bool IsTmaLoadSFA = ScaleMsPerTile >= ScaleTmaThreshold && ScaleNsPerTile < ScaleTmaThreshold && MMajorSFA; + static constexpr bool IsTmaLoadSFB = ScaleNsPerTile >= ScaleTmaThreshold && ScaleMsPerTile < ScaleTmaThreshold && NMajorSFB; // Two threads per CTA are producers (1 for operand tile `tma`, and 32 for scales `cp.async`) static constexpr int NumProducerThreadEvents = ((IsTmaLoadSFA && IsTmaLoadSFB)? 1 : 33); @@ -151,7 +154,12 @@ struct CollectiveMma< static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); - using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig; + using ScaleConfig = ::cutlass::detail::Sm90BlockwiseScaleConfig< + ScaleGranularityM, + ScaleGranularityN, + ScaleGranularityK, + MMajorSFA ? cute::GMMA::Major::MN : cute::GMMA::Major::K, + NMajorSFB ? cute::GMMA::Major::MN : cute::GMMA::Major::K>; using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(TileShape{})); using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(TileShape{})); @@ -170,8 +178,8 @@ struct CollectiveMma< using CopyAtomSFA = Copy_Atom, ElementBlockScale>; using CopyAtomSFB = Copy_Atom, ElementBlockScale>; - static constexpr int AlignmentSFA = 1; - static constexpr int AlignmentSFB = 1; + static constexpr int AlignmentSFA = IsTmaLoadSFA ? 128 / cutlass::sizeof_bits::value : 1; + static constexpr int AlignmentSFB = IsTmaLoadSFB ? 128 / cutlass::sizeof_bits::value : 1; // Block scaling smem layout using SmemLayoutSFA = decltype(make_layout( @@ -669,7 +677,7 @@ struct CollectiveMma< Tensor tSFAcSFA_compact = filter_zeros(tSFAcSFA); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tSFApSFA); ++i) { - tSFApSFA(i) = load_sfa && elem_less(get<0>(tSFAcSFA_compact(i)), get<0>(SFA_shape)); + tSFApSFA(i) = load_sfa && elem_less(tSFAcSFA_compact(i), SFA_shape); } bool load_sfb = thread_idx < ScaleNsPerTile; @@ -677,7 +685,7 @@ struct CollectiveMma< Tensor tSFBcSFB_compact = filter_zeros(tSFBcSFB); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(tSFBpSFB); ++i) { - tSFBpSFB(i) = load_sfb && elem_less(get<0>(tSFBcSFB_compact(i)), get<0>(SFB_shape)); + tSFBpSFB(i) = load_sfb && elem_less(tSFBcSFB_compact(i), SFB_shape); } int write_stage = smem_pipe_write.index(); // Copy scale tensors from global memory to shared memory diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 3508de00..390e41f8 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -393,6 +393,7 @@ public: [[maybe_unused]] dim3 fallback_cluster = dim3{0,0,0}; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || GemmKernel::ArchTag::kMinComputeCapability == 101 + || GemmKernel::ArchTag::kMinComputeCapability == 103 ) { if constexpr (!cute::is_static_v) { fallback_cluster = params.hw_info.cluster_shape_fallback; @@ -473,6 +474,7 @@ public: if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || GemmKernel::ArchTag::kMinComputeCapability == 101 || GemmKernel::ArchTag::kMinComputeCapability == 120 + || GemmKernel::ArchTag::kMinComputeCapability == 103 ) { if constexpr (is_static_1x1x1) { #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 6f010c1b..5f836ecd 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -34,14 +34,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(limits) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" #include "cutlass/device_kernel.h" diff --git a/include/cutlass/gemm/device/gemv_blockscaled.h b/include/cutlass/gemm/device/gemv_blockscaled.h new file mode 100644 index 00000000..b4dc0dd3 --- /dev/null +++ b/include/cutlass/gemm/device/gemv_blockscaled.h @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm_universal.h" + +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/device/gemm_universal_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemvBlockScaled { +public: + + using GemvKernel = GemvKernel_; + + + using ElementA = typename GemvKernel::ElementA; + using LayoutA = typename GemvKernel::LayoutA; + using ElementB = typename GemvKernel::ElementB; + using ElementC = typename GemvKernel::ElementC; + + using ElementSFA = typename GemvKernel::ElementSFA; + using ElementSFB = typename GemvKernel::ElementSFB; + + using ElementAccumulator = typename GemvKernel::ElementAccumulator; + using EpilogueOutputOp = typename GemvKernel::EpilogueOutputOp; + + static ComplexTransform const kTransformA = GemvKernel::kTransformA; + static ComplexTransform const kTransformB = GemvKernel::kTransformB; + + static int const kThreadCount = GemvKernel::kThreadCount; + static int const kThreadsPerRow = GemvKernel::kThreadsPerRow; + + using Arguments = typename GemvKernel::Arguments; + using Params = typename GemvKernel::Params; + +private: + + Params params_; + +public: + + /// Constructs the GemvBlockScaled. + GemvBlockScaled() = default; + + /// Determines whether the GemvBlockScaled can execute the given problem. + static Status can_implement(Arguments const &args) { + + return GemvKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + return 0; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const &args, dim3 const &block) { + if(platform::is_same::value) { + return dim3((args.problem_size.row() + (block.x - 1)) / block.x, 1, args.batch_count % 65536); + } + else { + return dim3((args.problem_size.row() + (block.y - 1)) / block.y, 1, args.batch_count % 65536); + } + } + + /// Computes the block shape + static dim3 get_block_shape() { + if(platform::is_same::value) { + return dim3(kThreadCount, 1, 1); + } + else { + return dim3(kThreadsPerRow, kThreadCount / kThreadsPerRow, 1); + } + } + + /// Initializes GemvBlockScaled state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + params_ = Params(args); + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + return params_.update(args); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + const dim3 block = get_block_shape(); + const dim3 grid = get_grid_shape(params_, block); + + int smem_size = int(sizeof(typename GemvKernel::SharedStorage)); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + if (result == cudaSuccess) { + return Status::kSuccess; + } else { + return Status::kErrorInternal; + } + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index d051aa7a..314f99f5 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -36,6 +36,7 @@ #include "cute/layout.hpp" #include "cute/numeric/integral_constant.hpp" // cute::false_type #include "cute/atom/copy_traits_sm100.hpp" +#include "cutlass/detail/collective/sm103_kernel_type.hpp" ////////////////////////////////////////////////////////////////////////////// namespace cutlass::detail { @@ -423,6 +424,21 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling "KernelSchedule must be one of the warp specialized FP8 block scale policies"); }; +////////////////////////////////////////////////////////////////////////////// + +// +// Kernel Scheduler Tag +// + +// Dense GEMM: SM100 tensor op policy that applies to both 1SM and 2SM MMA atoms +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelWarpSpecializedSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; template< int SchedulerPipelineStageCount_, @@ -461,6 +477,24 @@ struct KernelPtrArrayTmaWarpSpecializedMmaTransformSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedBlockScaledSm103 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelPtrArrayTmaWarpSpecializedBlockScaledSm103 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; + // Sparse Gemm template< int SchedulerPipelineStageCount_, @@ -665,6 +699,8 @@ struct KernelScheduleSm100DenseGemm : KernelScheduleSm100 {}; // Base policy // Dense GEMM: Specialize for 1SM vs 2SM struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder +struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array Dense GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -721,6 +757,8 @@ struct KernelScheduleSm100MixedInputGemm : KernelScheduleSm100 {}; struct KernelTmaWarpSpecializedMixedInputSmemSm100 : KernelScheduleSm100MixedInputGemm { }; struct KernelTmaWarpSpecialized1SmMixedInputSm100 final : KernelSchedule1Sm, KernelScheduleSm100MixedInputGemm { }; struct KernelTmaWarpSpecialized1SmMixedInputSmemSm100 final : KernelSchedule1Sm, KernelTmaWarpSpecializedMixedInputSmemSm100 { }; +struct KernelTmaWarpSpecialized2SmMixedInputSm100 final : KernelSchedule2Sm, KernelScheduleSm100MixedInputGemm { }; +struct KernelTmaWarpSpecialized2SmMixedInputSmemSm100 final : KernelSchedule2Sm, KernelTmaWarpSpecializedMixedInputSmemSm100 { }; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array FastF32 (9xBF16) GEMM Dispatch Policies @@ -789,6 +827,54 @@ struct KernelSparseTmaWarpSpecialized2SmNvf4Sm100 final : KernelSchedule2 struct KernelSparseTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelScheduleSparseMxNvf4Sm100 { }; struct KernelSparseTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleSparseMxNvf4Sm100 { }; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// +// SM103 Dispatch Policies +// +/////////////////////////////////////////////////////////////////////////////////////////////////////// + +struct KernelScheduleSm103 {}; +struct KernelScheduleSm103BlockScaledGemm : KernelScheduleSm103 {}; +struct KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch : KernelScheduleSm103BlockScaledGemm {}; +struct KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch : KernelScheduleSm103BlockScaledGemm {}; + +// Blockscaled Gemm: Specialized for instruction type, scale factor vector size, and 1SM vs. 2SM +// These are the public dispatch policy name +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraTmaPrefetch { }; + +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule1Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule2Sm, KernelScheduleSm103BlockScaledMxNvf4UltraDisablePrefetch { }; + +using KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 = KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch; +using KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 = KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch; +using KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 = KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch; +using KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 = KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch; + + +struct KernelSchedulePtrArraySm103BlockScaledGemm : KernelScheduleSm103 {}; +struct KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch : KernelSchedulePtrArraySm103BlockScaledGemm {}; +struct KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch : KernelSchedulePtrArraySm103BlockScaledGemm {}; + +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraTmaPrefetch { }; + +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule1Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; +struct KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch final : KernelSchedule2Sm, KernelSchedulePtrArraySm103BlockScaledMxNvf4UltraDisablePrefetch { }; + +using KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 = KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch; +using KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 = KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch; +using KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 = KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch; +using KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 = KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // // SM120 Dispatch Policies @@ -844,6 +930,25 @@ struct KernelSparseTmaWarpSpecializedMxf4Sm120 final : KernelScheduleS struct KernelSparseTmaWarpSpecializedMxf8f6f4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120 { }; struct KernelSparseTmaWarpSpecializedMxf8f6f4Acc2x4Sm120 final : KernelScheduleSparseMxf8f6f4Sm120, KernelScheduleAcc2x4Sm120 { }; +////////////////////////////////////////////////////////////////////////////// + +// +// Collective Mainloop Dispatch Policies +// + +// n-buffer in smem, pipelined with Blackwell UMMA and CPASYNC, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100UmmaCpAsyncWarpSpecialized { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelWarpSpecializedSm100; +}; // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< @@ -983,6 +1088,34 @@ struct MainloopSm100TmaUmmaWarpSpecializedFastF32 { }; +// n-buffer in smem, pipelined with Blackwell Mixed Input kernel with UMMA (HwScaled) and TMA, +template< + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Load2TransformPipelineStageCount_, + // Number of Pipeline stages for + // MainloopLoad <-> Conversion <-> MainLoad + int Transform2MmaPipelineStageCount_, + // TileScheduler pipeline depth + int SchedulerPipelineStageCount_, + // Accmulator pipeline depth + int AccumulatorPipelineStageCount_, + // ClusterShape for the kernel + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecializedMixedInput { + constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Load2MmaPipelineStageCount = Load2TransformPipelineStageCount_; + constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_; + constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::MixedInput; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedMixedInputTransformSm100; + + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = Load2TransformPipelineStageCount; +}; + // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< @@ -1064,9 +1197,49 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { }; +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int LoadABPipelineStageCount_, + int LoadSFPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch +> +struct MainloopSm103TmaUmmaWarpSpecializedBlockScaled { + constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; + constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm103; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelTmaWarpSpecializedBlockScaledSm103; + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = LoadABPipelineStageCount; + constexpr static cutlass::sm103::detail::KernelPrefetchType PrefetchType = PrefetchType_; +}; // Mainloop schedule for array-based TMA +template< + int LoadABPipelineStageCount_, + int LoadSFPipelineStageCount_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1>, + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch +> +struct MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled { + constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; + constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm103; + constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; + using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm103; + // For backwards compatibility with GemmUniversalAdapter. + constexpr static int Stages = LoadABPipelineStageCount; + constexpr static cutlass::sm103::detail::KernelPrefetchType PrefetchType = PrefetchType_; +}; + template< int Stages_, int SchedulerPipelineStageCount_, diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index e212a761..7b086e27 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -68,8 +68,13 @@ struct IsCutlass3ArrayKernel +struct GemvBlockScaled; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GEMV for row-major A matrix +template +struct GemvBlockScaled +{ +public: + using ElementA = ElementA_; + using ElementSFA = ElementSFA_; + using LayoutA = cutlass::layout::RowMajor; + using TensorRefA = cutlass::TensorRef; + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFA should be FP8 type"); + + using ElementB = ElementB_; + using ElementSFB = ElementSFB_; + using LayoutB = cutlass::layout::ColumnMajor; + static_assert(cutlass::sizeof_bits::value == 8, "ElementSFB should be FP8 type"); + + using ElementC = ElementC_; + using LayoutC = cutlass::layout::ColumnMajor; + + using ElementAccumulator = ElementAccumulator_; + + static constexpr cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; + static constexpr cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; + + static constexpr FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest; + + // number of return elements in a global access + static constexpr int kElementsPerAccess = kElementsPerAccess_; + static constexpr int kSFVecSize = kSFVecSize_; + static constexpr int kSFPerAccess = cutlass::const_max(1, kElementsPerAccess / kSFVecSize); + + static_assert(kSFVecSize == 16, "Only SFVecSize = 16 is supported"); + // Hardcode some check for easier debug + static_assert(kElementsPerAccess == 32, "for fp4 kernel, 32 elt per access"); + static_assert(kSFPerAccess == 2, "fpr fp4 kernel, 2 sf read per thread"); + + static constexpr bool kDequantizeA = cutlass::sizeof_bits::value == 4; + static constexpr bool kDequantizeB = cutlass::sizeof_bits::value == 4; + static constexpr int kPackedElementsA = cutlass::sizeof_bits::value == 4 ? 2 : 1; + static constexpr int kPackedElementsB = cutlass::sizeof_bits::value == 4 ? 2 : 1; + static constexpr int kPackedElements = cutlass::const_max(kPackedElementsA, kPackedElementsB); + + static_assert(kDequantizeA == true, "kDequantizeA should be true"); + static_assert(kDequantizeB == true, "kDequantizeB should be true"); + + using FragmentA = cutlass::Array; + using FragmentB = cutlass::Array; + using FragmentCompute = cutlass::Array; + using FragmentSFA = cutlass::Array; + using FragmentSFB = cutlass::Array; + using FragmentPackedA = cutlass::Array; + using FragmentPackedB = cutlass::Array; + + static_assert(sizeof_bits::value == 128, "FragmentA should be 128 bits"); + static_assert(sizeof_bits::value == 128, "FragmentB should be 128 bits"); + + // // thread block shape (kThreadsPerRow, kThreadCount / kThreadsPerRow, 1) + static constexpr int kThreadCount = (kThreadCount_ <= 0) ? 128 : kThreadCount_; + static constexpr int kThreadsPerRow = (kThreadsPerRow_ <= 0) ? + cutlass::const_min(static_cast(kThreadCount / cutlass::bits_to_bytes(kElementsPerAccess * cutlass::sizeof_bits::value)), 16) : + kThreadsPerRow_; + static constexpr int kThreadsPerCol = kThreadCount / kThreadsPerRow; + + static constexpr int kStageCount = 4; + static constexpr int kBufferCount = 2; + + // Number of elements stored in shared memory per stage for operands A and B. + // Each thread contributes `kElementsPerAccess / kPackedElements{A,B}` packed + // values. + static constexpr int kSmemPerStageA = kThreadCount * kElementsPerAccess / kPackedElementsA; + // B is uniform across all threads in the same k-column, so only store it once per k-thread + static constexpr int kSmemPerStageB = kThreadsPerRow * kElementsPerAccess / kPackedElementsB; + + using EpilogueOutputOp = EpilogueOutputOp_; + + // Ensure epilogue and mainloop have same thread layout + static_assert(kThreadCount == EpilogueOutputOp::kThreadCount, "mainloop, epilogue thread count mismatch"); + static_assert(kThreadsPerRow == EpilogueOutputOp::kThreadsPerRow, "mainloop, epilogue thread per row mismatch"); + static_assert(kThreadsPerCol == EpilogueOutputOp::kThreadsPerCol, "mainloop, epilogue thread per col mismatch"); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + MatrixCoord problem_size; + int32_t batch_count{0}; + typename EpilogueOutputOp::Params epilogue; + + TensorRefA ref_A; + + ElementB const *ptr_B{nullptr}; + ElementC const *ptr_C{nullptr}; + ElementC *ptr_D{nullptr}; + + ElementSFA const *ptr_SFA{nullptr}; + ElementSFB const *ptr_SFB{nullptr}; + + int64_t stride_A{0}; + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int64_t batch_stride_SFA{0}; + int64_t batch_stride_SFB{0}; + int64_t batch_stride_SFD{0}; + }; + + using Params = Arguments; + + /// Shared memory storage structure + struct SharedStorage + { + using EpilogueStorage = typename EpilogueOutputOp::SharedStorage; + EpilogueStorage epilogue; + + alignas(16) ElementA smem_A[kBufferCount][kStageCount][kSmemPerStageA]; + alignas(16) ElementB smem_B[kBufferCount][kStageCount][kSmemPerStageB]; + alignas(16) ElementSFA smem_SFA[kBufferCount][kStageCount][kThreadCount * kSFPerAccess]; + alignas(16) ElementSFB smem_SFB[kBufferCount][kStageCount][kThreadsPerRow * kSFPerAccess]; + }; + +public: + // + // Methods + // + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::MatrixCoord const &problem_size) + { + if (problem_size.column() % kElementsPerAccess != 0) { + return Status::kErrorMisalignedOperand; + } + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) + { + return can_implement(args.problem_size); + } + + /// Executes one GEMV + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) + { + EpilogueOutputOp epilogue(params.epilogue, shared_storage.epilogue); + + // Converters only needed for regular GEMV fallback case + NumericConverter A_converter; + NumericConverter B_converter; + NumericConverter SFA_converter; + NumericConverter SFB_converter; + + const int32_t gemm_m = params.problem_size.row(); + [[maybe_unused]] static constexpr int32_t gemm_n = 1; + const int32_t gemm_k = params.problem_size.column(); + const int32_t gemm_batch = params.batch_count; + + // Loop over batch indices + for (int batch_idx = blockIdx.z; batch_idx < gemm_batch; batch_idx += gridDim.z) { + + int idx_col_k = threadIdx.x; + int idx_row_m = blockIdx.x * blockDim.y + threadIdx.y; + + if (idx_row_m < gemm_m) { + // problem_size (row = m, column = k) + // matrix A (batch, m, k) + // vector B (batch, k, 1) + // vector C (batch, m, 1) + // vector D (batch, m, 1) + // move in the batch dimension + ElementA const *ptr_A = params.ref_A.data() + batch_idx * params.batch_stride_A / kPackedElementsA; + ElementB const *ptr_B = params.ptr_B + batch_idx * params.batch_stride_B / kPackedElementsB; + ElementC const *ptr_C = params.ptr_C + batch_idx * params.batch_stride_C; + ElementC *ptr_D = params.ptr_D + batch_idx * params.batch_stride_D; + + // move in the k dimension + ptr_A += idx_col_k * kElementsPerAccess / kPackedElementsA; + ptr_B += idx_col_k * kElementsPerAccess / kPackedElementsB; + + // move in the m dimension + ptr_A += idx_row_m * params.stride_A / kPackedElementsA; + ptr_C += idx_row_m; + ptr_D += idx_row_m; + + ElementSFA const *ptr_SF_A{nullptr}; + ElementSFB const *ptr_SF_B{nullptr}; + int global_k{0}; + + int SF_blocks_by_M = (gemm_m + 127) >> 7; + int SF_blocks_by_K = (gemm_k / kSFVecSize + 3) >> 2; + + // move in the batch dimension + ptr_SF_A = params.ptr_SFA + batch_idx * SF_blocks_by_M * SF_blocks_by_K * 512; + ptr_SF_B = params.ptr_SFB + batch_idx * SF_blocks_by_K * 512; + + // move in the m dimension + ptr_SF_A += (((idx_row_m >> 7) * SF_blocks_by_K) << 9) + ((idx_row_m & 0x1f) << 4) + ((idx_row_m & 0x7f) >> 5 << 2); + + global_k = idx_col_k * kElementsPerAccess; + + ElementAccumulator accum = ElementAccumulator(0); + + // Local aliases + const int tileA_k_local = kThreadsPerRow * kElementsPerAccess; + const int total_tiles = gemm_k / tileA_k_local; + + int unroll_col_k = 0; // total K elements consumed so far by this thread + const int thread_id = threadIdx.y * kThreadsPerRow + threadIdx.x; + const bool is_even_thread = (threadIdx.x % 2 == 0); + const bool load_b = (threadIdx.y == 0); + const int smem_sf_write_offset = (thread_id / 2) * 4; // 4 FP8 per even thread + const int smem_sf_offset = thread_id * kSFPerAccess; + + // Fast path: if the problem fits entirely in the tail path, skip SMEM + if (total_tiles == 0) { + accum += process_tail_elements(0, idx_col_k, gemm_k, + ptr_A, ptr_B, + ptr_SF_A, ptr_SF_B, + A_converter, B_converter, + SFA_converter, SFB_converter); + } else { + + // Scaling factors are now loaded from shared memory, no register pipeline needed + + // Thread-local SMEM line offset + const int thread_linear = threadIdx.y * kThreadsPerRow + threadIdx.x; + const int smem_offset_A = thread_linear * (kElementsPerAccess / kPackedElementsA); + // Only one row of threads (threadIdx.y == 0) loads B + const int smem_offset_B = threadIdx.x * (kElementsPerAccess / kPackedElementsB); + + // PROLOGUE – prime first kStageCount-1 stages into buffer 0 + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBufferCount - 1; ++b) { + // Load all stages using the helper function + load_stages_gmem_to_smem( + b, // buffer_idx + kStageCount, // num_stages + unroll_col_k, // passed by reference + global_k, // passed by reference + tileA_k_local, + smem_offset_A, + smem_offset_B, + smem_sf_write_offset, + is_even_thread, + load_b, + true, // valid_tile = true for prologue + ptr_A, + ptr_B, + ptr_SF_A, + ptr_SF_B, + shared_storage); + } + cutlass::arch::cp_async_fence(); + + // Ensure first stage committed + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Register double buffering for A/B fragments and SFA/SFB like SM80 + FragmentA fragA_reg[2]; + FragmentB fragB_reg[2]; + FragmentSFA fragSFA_reg[2]; + FragmentSFB fragSFB_reg[2]; + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = kBufferCount - 1; + + // PREFETCH register pipeline - load first kblock (stage 0) into register bank 0 + if constexpr (kStageCount > 1) + { + int frag_idx = 0; + + // Load fragments using the helper function + load_smem_fragments( + fragA_reg[frag_idx], + fragB_reg[frag_idx], + fragSFA_reg[frag_idx], + fragSFB_reg[frag_idx], + smem_pipe_read, + 0, // k_block = 0 + smem_offset_A, + smem_offset_B, + smem_sf_offset, + shared_storage); + + } + + // Mainloop + int tile_idx = 0; + while (tile_idx < total_tiles) { + int smem_pipe_read_curr = smem_pipe_read; + + for_each(make_int_sequence{}, [&] (auto k_block) + { + if (k_block == kStageCount - 1) + { + cutlass::arch::cp_async_wait(); + __syncthreads(); + + smem_pipe_read_curr = smem_pipe_read; + } + + // Load A/B/SFA/SFB smem->regs for k_block_next + auto k_block_next = (k_block + Int<1>{}) % kStageCount; + int frag_idx_next = (k_block + 1) & 1; + + // Prefetch next kblock data using saved pipe index + load_smem_fragments( + fragA_reg[frag_idx_next], + fragB_reg[frag_idx_next], + fragSFA_reg[frag_idx_next], + fragSFB_reg[frag_idx_next], + smem_pipe_read_curr, + k_block_next, + smem_offset_A, + smem_offset_B, + smem_sf_offset, + shared_storage); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // Use predicate instead of branch for cp_async + bool valid_tile = (global_k < gemm_k); + + // Load all stages using the helper function + load_stages_gmem_to_smem( + smem_pipe_write, // buffer_idx + kStageCount, // num_stages + unroll_col_k, // passed by reference + global_k, // passed by reference + tileA_k_local, + smem_offset_A, + smem_offset_B, + smem_sf_write_offset, + is_even_thread, + load_b, + valid_tile, + ptr_A, + ptr_B, + ptr_SF_A, + ptr_SF_B, + shared_storage); + + cutlass::arch::cp_async_fence(); + + // Advance the pipe indices + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == kBufferCount) ? 0 : smem_pipe_read; + } + + { + int frag_idx = k_block & 1; + + // Compute using current fragments + accum += blockscaled_multiply_add( + fragA_reg[frag_idx], fragB_reg[frag_idx], + fragSFA_reg[frag_idx], + fragSFB_reg[frag_idx]); + } + }); + + tile_idx += kStageCount; + } + + // Drain outstanding async copies + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + // Tail elements that don't fill a full tile + if (unroll_col_k + idx_col_k * kPackedElementsA < gemm_k) { + accum += process_tail_elements(unroll_col_k, idx_col_k, gemm_k, + ptr_A, ptr_B, + ptr_SF_A, ptr_SF_B, + A_converter, B_converter, + SFA_converter, SFB_converter); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int mask = (kThreadsPerRow >> 1); mask > 0; mask >>= 1) { + accum += ElementAccumulator(__shfl_xor_sync(0xFFFFFFFF, static_cast(accum), mask, 32)); + } + + auto frag_acc = static_cast(accum); + auto frag_c = static_cast(*(ptr_C)); + + // Applying blockscaled epilogue + epilogue(frag_acc, frag_c, batch_idx); + } + } + } //end of operator() + +private: + // Load multiple stages from global to shared memory + CUTLASS_DEVICE + void load_stages_gmem_to_smem( + int buffer_idx, + int num_stages, + int& unroll_col_k, + int& global_k, + int tileA_k_local, + int smem_offset_A, + int smem_offset_B, + int smem_sf_write_offset, + bool is_even_thread, + bool load_b, + bool valid_tile, + ElementA const* ptr_A, + ElementB const* ptr_B, + ElementSFA const* ptr_SF_A, + ElementSFB const* ptr_SF_B, + SharedStorage& shared_storage) { + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < num_stages; ++s) { + // Load scaling factors using cp.async - only even threads participate + // Calculate SF indices for this thread + int SF_idx = global_k / kSFVecSize; + int SF_offset_by_k = ((SF_idx >> 2) << 9) + (SF_idx & 0x3); + + void *smem_ptr_SFA = &shared_storage.smem_SFA[buffer_idx][s][smem_sf_write_offset]; + const void *gmem_ptr_SFA = ptr_SF_A + SF_offset_by_k; + // Load 4 FP8 values (32 bits) - for this thread and next thread + cutlass::arch::cp_async(smem_ptr_SFA, gmem_ptr_SFA, valid_tile && is_even_thread); + + void *smem_ptr_SFB = &shared_storage.smem_SFB[buffer_idx][s][(threadIdx.x / 2) * 4]; + const void *gmem_ptr_SFB = ptr_SF_B + SF_offset_by_k; + // Load 4 FP8 values (32 bits) - for this thread and next thread, only if threadIdx.y == 0 + cutlass::arch::cp_async(smem_ptr_SFB, gmem_ptr_SFB, valid_tile && load_b && is_even_thread); + + void *smem_ptr_A = &shared_storage.smem_A[buffer_idx][s][smem_offset_A]; + const void *gmem_ptr_A = ptr_A + unroll_col_k / kPackedElementsA; + cutlass::arch::cp_async(smem_ptr_A, gmem_ptr_A, valid_tile); + + void *smem_ptr_B = &shared_storage.smem_B[buffer_idx][s][smem_offset_B]; + const void *gmem_ptr_B = ptr_B + unroll_col_k / kPackedElementsB; + cutlass::arch::cp_async(smem_ptr_B, gmem_ptr_B, valid_tile && load_b); + + unroll_col_k += tileA_k_local; + global_k += tileA_k_local; + } + } + + /// Fused blockscaled GEMV computation using PTX + CUTLASS_DEVICE + ElementAccumulator blockscaled_multiply_add( + FragmentA const& fragA, + FragmentB const& fragB, + FragmentSFA const& fragSFA, + FragmentSFB const& fragSFB) { + + #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) + uint16_t const& src_fragSFA_packed = reinterpret_cast(fragSFA); + uint16_t const& src_fragSFB_packed = reinterpret_cast(fragSFB); + + uint32_t const* src_fragA_packed = reinterpret_cast(&fragA); + uint32_t const* src_fragB_packed = reinterpret_cast(&fragB); + + ElementAccumulator out; + uint16_t* out_fp16 = reinterpret_cast(&out); + + asm volatile( \ + "{\n" \ + // declare registers for A / B tensors + ".reg .b8 byte0_0, byte0_1, byte0_2, byte0_3;\n" \ + ".reg .b8 byte0_4, byte0_5, byte0_6, byte0_7;\n" \ + ".reg .b8 byte1_0, byte1_1, byte1_2, byte1_3;\n" \ + ".reg .b8 byte1_4, byte1_5, byte1_6, byte1_7;\n" \ + ".reg .b8 byte2_0, byte2_1, byte2_2, byte2_3;\n" \ + ".reg .b8 byte2_4, byte2_5, byte2_6, byte2_7;\n" \ + ".reg .b8 byte3_0, byte3_1, byte3_2, byte3_3;\n" \ + ".reg .b8 byte3_4, byte3_5, byte3_6, byte3_7;\n" \ + + // declare registers for accumulators + ".reg .f16x2 accum_0_0, accum_0_1, accum_0_2, accum_0_3;\n" \ + ".reg .f16x2 accum_1_0, accum_1_1, accum_1_2, accum_1_3;\n" \ + ".reg .f16x2 accum_2_0, accum_2_1, accum_2_2, accum_2_3;\n" \ + ".reg .f16x2 accum_3_0, accum_3_1, accum_3_2, accum_3_3;\n" \ + + // declare registers for scaling factors + ".reg .f16x2 sfa_f16x2;\n" \ + ".reg .f16x2 sfb_f16x2;\n" \ + ".reg .f16x2 sf_f16x2;\n" \ + + // declare registers for conversion + ".reg .f16x2 cvt_0_0, cvt_0_1, cvt_0_2, cvt_0_3;\n" \ + ".reg .f16x2 cvt_0_4, cvt_0_5, cvt_0_6, cvt_0_7;\n" \ + ".reg .f16x2 cvt_1_0, cvt_1_1, cvt_1_2, cvt_1_3;\n" \ + ".reg .f16x2 cvt_1_4, cvt_1_5, cvt_1_6, cvt_1_7;\n" \ + ".reg .f16x2 cvt_2_0, cvt_2_1, cvt_2_2, cvt_2_3;\n" \ + ".reg .f16x2 cvt_2_4, cvt_2_5, cvt_2_6, cvt_2_7;\n" \ + ".reg .f16x2 cvt_3_0, cvt_3_1, cvt_3_2, cvt_3_3;\n" \ + ".reg .f16x2 cvt_3_4, cvt_3_5, cvt_3_6, cvt_3_7;\n" \ + ".reg .f16 result_f16, lane0, lane1;\n" \ + ".reg .f16x2 mul_f16x2_0, mul_f16x2_1;\n" \ + + // convert scaling factors from fp8 to f16x2 + "cvt.rn.f16x2.e4m3x2 sfa_f16x2, %1;\n" \ + "cvt.rn.f16x2.e4m3x2 sfb_f16x2, %2;\n" \ + + // clear accumulators + "mov.b32 accum_0_0, 0;\n" \ + "mov.b32 accum_0_1, 0;\n" \ + "mov.b32 accum_0_2, 0;\n" \ + "mov.b32 accum_0_3, 0;\n" \ + "mov.b32 accum_1_0, 0;\n" \ + "mov.b32 accum_1_1, 0;\n" \ + "mov.b32 accum_1_2, 0;\n" \ + "mov.b32 accum_1_3, 0;\n" \ + "mov.b32 accum_2_0, 0;\n" \ + "mov.b32 accum_2_1, 0;\n" \ + "mov.b32 accum_2_2, 0;\n" \ + "mov.b32 accum_2_3, 0;\n" \ + "mov.b32 accum_3_0, 0;\n" \ + "mov.b32 accum_3_1, 0;\n" \ + "mov.b32 accum_3_2, 0;\n" \ + "mov.b32 accum_3_3, 0;\n" \ + + // multiply, unpacking and permuting scale factors + "mul.rn.f16x2 sf_f16x2, sfa_f16x2, sfb_f16x2;\n" \ + "mov.b32 {lane0, lane1}, sf_f16x2;\n" \ + "mov.b32 mul_f16x2_0, {lane0, lane0};\n" \ + "mov.b32 mul_f16x2_1, {lane1, lane1};\n" \ + + // unpacking A and B tensors + "mov.b32 {byte0_0, byte0_1, byte0_2, byte0_3}, %3;\n" \ + "mov.b32 {byte0_4, byte0_5, byte0_6, byte0_7}, %4;\n" \ + "mov.b32 {byte1_0, byte1_1, byte1_2, byte1_3}, %5;\n" \ + "mov.b32 {byte1_4, byte1_5, byte1_6, byte1_7}, %6;\n" \ + "mov.b32 {byte2_0, byte2_1, byte2_2, byte2_3}, %7;\n" \ + "mov.b32 {byte2_4, byte2_5, byte2_6, byte2_7}, %8;\n" \ + "mov.b32 {byte3_0, byte3_1, byte3_2, byte3_3}, %9;\n" \ + "mov.b32 {byte3_4, byte3_5, byte3_6, byte3_7}, %10;\n" \ + + // convert A and B tensors from fp4 to f16x2 + + // A[0 - 7] and B[0 - 7] + "cvt.rn.f16x2.e2m1x2 cvt_0_0, byte0_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_1, byte0_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_2, byte0_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_3, byte0_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_4, byte0_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_5, byte0_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_6, byte0_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_0_7, byte0_7;\n" \ + + // A[8 - 15] and B[8 - 15] + "cvt.rn.f16x2.e2m1x2 cvt_1_0, byte1_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_1, byte1_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_2, byte1_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_3, byte1_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_4, byte1_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_5, byte1_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_6, byte1_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_1_7, byte1_7;\n" \ + + // A[16 - 23] and B[16 - 23] + "cvt.rn.f16x2.e2m1x2 cvt_2_0, byte2_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_1, byte2_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_2, byte2_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_3, byte2_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_4, byte2_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_5, byte2_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_6, byte2_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_2_7, byte2_7;\n" \ + + // A[24 - 31] and B[24 - 31] + "cvt.rn.f16x2.e2m1x2 cvt_3_0, byte3_0;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_1, byte3_1;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_2, byte3_2;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_3, byte3_3;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_4, byte3_4;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_5, byte3_5;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_6, byte3_6;\n" \ + "cvt.rn.f16x2.e2m1x2 cvt_3_7, byte3_7;\n" \ + + // fma for A[0 - 7] and B[0 - 7] + "fma.rn.f16x2 accum_0_0, cvt_0_0, cvt_0_4, accum_0_0;\n" \ + "fma.rn.f16x2 accum_0_1, cvt_0_1, cvt_0_5, accum_0_1;\n" \ + "fma.rn.f16x2 accum_0_2, cvt_0_2, cvt_0_6, accum_0_2;\n" \ + "fma.rn.f16x2 accum_0_3, cvt_0_3, cvt_0_7, accum_0_3;\n" \ + + // fma for A[8 - 15] and B[8 - 15] + "fma.rn.f16x2 accum_1_0, cvt_1_0, cvt_1_4, accum_1_0;\n" \ + "fma.rn.f16x2 accum_1_1, cvt_1_1, cvt_1_5, accum_1_1;\n" \ + "fma.rn.f16x2 accum_1_2, cvt_1_2, cvt_1_6, accum_1_2;\n" \ + "fma.rn.f16x2 accum_1_3, cvt_1_3, cvt_1_7, accum_1_3;\n" \ + + // fma for A[16 - 23] and B[16 - 23] + "fma.rn.f16x2 accum_2_0, cvt_2_0, cvt_2_4, accum_2_0;\n" \ + "fma.rn.f16x2 accum_2_1, cvt_2_1, cvt_2_5, accum_2_1;\n" \ + "fma.rn.f16x2 accum_2_2, cvt_2_2, cvt_2_6, accum_2_2;\n" \ + "fma.rn.f16x2 accum_2_3, cvt_2_3, cvt_2_7, accum_2_3;\n" \ + + // fma for A[24 - 31] and B[24 - 31] + "fma.rn.f16x2 accum_3_0, cvt_3_0, cvt_3_4, accum_3_0;\n" \ + "fma.rn.f16x2 accum_3_1, cvt_3_1, cvt_3_5, accum_3_1;\n" \ + "fma.rn.f16x2 accum_3_2, cvt_3_2, cvt_3_6, accum_3_2;\n" \ + "fma.rn.f16x2 accum_3_3, cvt_3_3, cvt_3_7, accum_3_3;\n" \ + + // tree reduction for accumulators + "add.rn.f16x2 accum_0_0, accum_0_0, accum_0_1;\n" \ + "add.rn.f16x2 accum_0_2, accum_0_2, accum_0_3;\n" \ + "add.rn.f16x2 accum_1_0, accum_1_0, accum_1_1;\n" \ + "add.rn.f16x2 accum_1_2, accum_1_2, accum_1_3;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_2_1;\n" \ + "add.rn.f16x2 accum_2_2, accum_2_2, accum_2_3;\n" \ + "add.rn.f16x2 accum_3_0, accum_3_0, accum_3_1;\n" \ + "add.rn.f16x2 accum_3_2, accum_3_2, accum_3_3;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_0_2;\n" \ + "add.rn.f16x2 accum_1_0, accum_1_0, accum_1_2;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_2_2;\n" \ + "add.rn.f16x2 accum_3_0, accum_3_0, accum_3_2;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_1_0;\n" \ + "add.rn.f16x2 accum_2_0, accum_2_0, accum_3_0;\n" \ + + // apply scaling factors and final reduction + "mul.rn.f16x2 accum_0_0, mul_f16x2_0, accum_0_0;\n" \ + "mul.rn.f16x2 accum_2_0, mul_f16x2_1, accum_2_0;\n" \ + + "add.rn.f16x2 accum_0_0, accum_0_0, accum_2_0;\n" \ + + "mov.b32 {lane0, lane1}, accum_0_0;\n" \ + "add.rn.f16 result_f16, lane0, lane1;\n" \ + + "mov.b16 %0, result_f16;\n" \ + + "}\n" + : "=h"(out_fp16[0]) // 0 + : "h"(src_fragSFA_packed), "h"(src_fragSFB_packed), // 1, 2 + "r"(src_fragA_packed[0]), "r"(src_fragB_packed[0]), // 3, 4 + "r"(src_fragA_packed[1]), "r"(src_fragB_packed[1]), // 5, 6 + "r"(src_fragA_packed[2]), "r"(src_fragB_packed[2]), // 7, 8 + "r"(src_fragA_packed[3]), "r"(src_fragB_packed[3]) // 9, 10 + : "memory" + ); + + return out; + + #else + NumericArrayConverter srcA_converter; + NumericArrayConverter srcB_converter; + NumericConverter SFA_converter; + NumericConverter SFB_converter; + + FragmentCompute fragA_Compute = srcA_converter(fragA); + FragmentCompute fragB_Compute = srcB_converter(fragB); + ElementAccumulator accum = ElementAccumulator(0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kSFPerAccess; i++) { + ElementAccumulator accum_SF_block = ElementAccumulator(0); + + int local_k_offset = i * kSFVecSize; + ElementAccumulator multiplier{1}; + + multiplier = SFA_converter(fragSFA.at(i)) * SFB_converter(fragSFB.at(i)); + + + CUTLASS_PRAGMA_UNROLL + for (int e = 0; e < kSFVecSize; e++) { + accum_SF_block += fragA_Compute.at(e + local_k_offset) * fragB_Compute.at(e + local_k_offset); + } + + accum_SF_block *= multiplier; + accum += accum_SF_block; + } + + return accum; + + #endif + } + + CUTLASS_DEVICE + ElementAccumulator process_tail_elements( + int unroll_col_k, + int idx_col_k, + int gemm_k, + ElementA const *ptr_A, + ElementB const *ptr_B, + ElementSFA const *ptr_SF_A, + ElementSFB const *ptr_SF_B, + NumericConverter const &A_converter, + NumericConverter const &B_converter, + NumericConverter const &SFA_converter, + NumericConverter const &SFB_converter) { + + ElementAccumulator accum = ElementAccumulator(0); + + // calculate the rest of K elements + // each thread fetch 1 element each time + for (int k = unroll_col_k + idx_col_k * kPackedElementsA; k < gemm_k; k += kThreadsPerRow * kPackedElementsA) { + // blockscaled GEMV + int SF_idx = k / kSFVecSize; + int SF_offset_by_k = ((SF_idx >> 2) << 9) + (SF_idx & 0x3); + + ElementSFA sfa = *(ptr_SF_A + SF_offset_by_k); + ElementSFB sfb = *(ptr_SF_B + SF_offset_by_k); + + FragmentPackedA fragA; + FragmentPackedB fragB; + + // fetch from matrix A + arch::global_load( + fragA, + ptr_A - (idx_col_k * kElementsPerAccess - k) / kPackedElementsA, + true); + + // fetch from vector B + arch::global_load( + fragB, + ptr_B - (idx_col_k * kElementsPerAccess - k) / kPackedElementsB, + true); + + ElementAccumulator accum_SF_packed = ElementAccumulator(0); + + CUTLASS_PRAGMA_UNROLL + for (int e = 0; e < kPackedElements; e++) { + accum_SF_packed += A_converter(fragA.at(e)) * B_converter(fragB.at(e)); + } + + accum_SF_packed *= SFA_converter(sfa) * SFB_converter(sfb); + + accum += accum_SF_packed; + + } + + return accum; + } + + // Load fragments from shared memory + template + CUTLASS_DEVICE + void load_smem_fragments( + FragmentA& fragA, + FragmentB& fragB, + FragmentSFA& fragSFA, + FragmentSFB& fragSFB, + int smem_pipe_idx, + int k_block, + int smem_offset_A, + int smem_offset_B, + int smem_sf_offset, + SharedStorage& shared_storage) const { + + // Load A/B fragments + arch::shared_load(fragA, &shared_storage.smem_A[smem_pipe_idx][k_block][smem_offset_A]); + arch::shared_load(fragB, &shared_storage.smem_B[smem_pipe_idx][k_block][smem_offset_B]); + + // Load SF fragments + uint32_t smem_ptr = cutlass::arch::cutlass_get_smem_pointer(&shared_storage.smem_SFA[smem_pipe_idx][k_block][smem_sf_offset]); + arch::shared_load<2>(&fragSFA, smem_ptr); + smem_ptr = cutlass::arch::cutlass_get_smem_pointer(&shared_storage.smem_SFB[smem_pipe_idx][k_block][threadIdx.x * kSFPerAccess]); + arch::shared_load<2>(&fragSFB, smem_ptr); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp new file mode 100644 index 00000000..21ff5959 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp @@ -0,0 +1,793 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = false; + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = CollectiveMainloop::NumLoadThreads; // 4 warps + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipelines and pipeline states + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + EpilogueLoad = 3, + Epilogue = 4, + MainloopLoad = 8 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_load = false; + }; + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + static constexpr int MaxClusterSize = 16; + implementable &= size(ClusterShape{}) <= MaxClusterSize; + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + status = cutlass::Status::kSuccess; + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto blk_shape = CtaShape_MNK{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info + ); + + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + +public: + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::MainloopLoad) ? WarpCategory::Epilogue + : WarpCategory::MainloopLoad; + uint32_t lane_predicate = cute::elect_one_sync(); + auto tile_shape = TileShape{}; + auto cluster_shape = ClusterShape{}; + constexpr int cluster_size = size(ClusterShape{}); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopLoad) // main_load + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + + mainloop_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + mainloop_pipeline_params.consumer_arv_count = 1; // Only UMMA consumes the A and B buffers + mainloop_pipeline_params.dst_blockid = cta_rank_in_cluster; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, cluster_shape); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 3; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // + // END PROLOGUE + // + + // Synchronization call. Blocks until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + cutlass::arch::warpgroup_reg_dealloc(); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(load_inputs); + + do { + // Get current work tile and fetch next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + auto [mainloop_producer_state_next, unused_] = collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + mainloop_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_count + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + } + + else if (is_participant.sched) { + cutlass::arch::warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.mma) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + // Pass the acc with tuple type since the bgrad kernel change the mma_init API + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, cute::make_tuple(bulk_tmem, bulk_tmem), shared_storage.tensors.mainloop); + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + + int acc_stage = accumulator_pipe_producer_state.index(); + Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + mainloop_pipe_consumer_state = collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + // Pass the acc with tuple type since the bgrad kernel change the mma API + cute::make_tuple(accumulators, accumulators), + mma_inputs, + k_tile_count + ); + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + + ++accumulator_pipe_producer_state; + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + cutlass::arch::warpgroup_reg_dealloc(); + + bool do_tail_load = false; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + cutlass::arch::warpgroup_reg_alloc(); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem.data() = tmem_base_ptr; + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // Accumulator stage slice + int acc_stage = accumulator_pipe_consumer_state.index(); + Tensor accumulators = bulk_tmem(_,_,_,acc_stage); + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + cutlass::arch::warpgroup_reg_dealloc(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp new file mode 100644 index 00000000..55c18c9a --- /dev/null +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp @@ -0,0 +1,1090 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + + // Get Blk and Scheduling tile shapes + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static constexpr bool IsComplex = DispatchPolicy::InputTransformType == cutlass::gemm::detail::KernelInputTransformType::InterleavedComplexTF32; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + // TileID scheduler + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveMainloop::NumAccumThreads; // 4 warps + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumTransformationThreads = CollectiveMainloop::NumTransformationThreads; // 4 warps + static constexpr uint32_t NumMainloopLoadBThreads = NumThreadsPerWarp; // 1 warp + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + + NumEpilogueThreads + NumTransformationThreads + NumMainloopLoadBThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr cutlass::gemm::detail::KernelInputTransformType InputTransformType = DispatchPolicy::InputTransformType; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + // Pipeline and pipeline state types + using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; + using Load2TransformPipelineState = typename CollectiveMainloop::Load2TransformPipelineState; + + using Load2MmaPipeline = typename CollectiveMainloop::Load2MmaPipeline; + using Load2MmaPipelineState = typename CollectiveMainloop::Load2MmaPipelineState; + + using Transform2MmaPipeline = typename CollectiveMainloop::Transform2MmaPipeline; + using Transform2MmaPipelineState = typename CollectiveMainloop::Transform2MmaPipelineState; + + using Mma2AccumPipeline = typename CollectiveMainloop::Mma2AccumPipeline; + using Mma2AccumPipelineState = typename CollectiveMainloop::Mma2AccumPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = cutlass::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + // Kernel level shared memory storage + struct SharedStorage { + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4, + // Transformation starts at 256 thread alignment + Transformation = 8, + MainloopLoadB = 12, + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t main_loadA = false; + uint32_t main_loadB = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t transformation = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + static constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + static constexpr uint32_t NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + auto blk_shape = CtaShape_MNK{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for multiple epilogue and transformation warps + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? WarpCategory(warp_idx) + : warp_idx < static_cast(WarpCategory::Transformation) ? WarpCategory::Epilogue + : warp_idx < static_cast(WarpCategory::MainloopLoadB) ? WarpCategory::Transformation + : WarpCategory::MainloopLoadB; + + int thread_idx = int(threadIdx.x); + int thread_idx_in_warp = thread_idx % 32; + uint32_t lane_predicate = cute::elect_one_sync(); + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + bool is_first_cta_in_cluster = (cta_rank_in_cluster == 0); + bool is_mma_leader_cta = (cta_rank_in_cluster % size<0>(TiledMma{}) == 0); + // Even if this variable is unused, shape_div still performs useful compile-time checks. + [[maybe_unused]] auto mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && (is_first_cta_in_cluster), // sched + (warp_category == WarpCategory::MainloopLoad || warp_category == WarpCategory::MainloopLoadB), // main_load + (warp_category == WarpCategory::MainloopLoad), // main_loadA + (warp_category == WarpCategory::MainloopLoadB), // main_loadB + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::Transformation) // transformation + }; + + // MainloopLoad <--> Transformation Pipeline + typename Load2TransformPipeline::Params load2transform_pipeline_params; + if (warp_category == WarpCategory::MainloopLoad) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Transformation) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + } + load2transform_pipeline_params.is_leader = (thread_idx_in_warp == 0); + load2transform_pipeline_params.num_consumers = NumTransformationThreads; + load2transform_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes_A; + load2transform_pipeline_params.initializing_warp = 0; + Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.mainloop.load2transform_pipeline, + load2transform_pipeline_params, + cluster_shape, + McastDirection::kRow, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2TransformPipelineState load2transform_pipeline_consumer_state; + Load2TransformPipelineState load2transform_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MainloopLoad <--> MMA Pipeline + typename Load2MmaPipeline::Params load2mma_pipeline_params; + if (warp_category == WarpCategory::MainloopLoadB) { + load2mma_pipeline_params.role = Load2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + load2mma_pipeline_params.role = Load2MmaPipeline::ThreadCategory::Consumer; + } + load2mma_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_loadB; + load2mma_pipeline_params.num_consumers = NumMMAThreads; + load2mma_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes_B; + load2mma_pipeline_params.initializing_warp = 8; + Load2MmaPipeline load2mma_pipeline(shared_storage.pipelines.mainloop.load2mma_pipeline, + load2mma_pipeline_params, + cluster_shape, + McastDirection::kCol, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Load2MmaPipelineState load2mma_pipeline_consumer_state; + Load2MmaPipelineState load2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + + // Transformation <--> MMA pipeline + typename Transform2MmaPipeline::Params transform2mma_pipeline_params; + if (warp_category == WarpCategory::Transformation) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::MMA) { + transform2mma_pipeline_params.role = Transform2MmaPipeline::ThreadCategory::Consumer; + } + transform2mma_pipeline_params.consumer_arv_count = 1; + transform2mma_pipeline_params.producer_arv_count = size(AtomThrShapeMNK{}) * NumTransformationThreads; + transform2mma_pipeline_params.initializing_warp = 2; + Transform2MmaPipeline transform2mma_pipeline(shared_storage.pipelines.mainloop.transform2mma_pipeline, + transform2mma_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Transform2MmaPipelineState transform2mma_pipeline_consumer_state; + Transform2MmaPipelineState transform2mma_pipeline_producer_state = cutlass::make_producer_start_state(); + + // MMA <--> Accumulator pipeline + typename Mma2AccumPipeline::Params mma2accum_pipeline_params; + if (warp_category == WarpCategory::MMA) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Producer; + } + else if (warp_category == WarpCategory::Epilogue) { + mma2accum_pipeline_params.role = Mma2AccumPipeline::ThreadCategory::Consumer; + } + mma2accum_pipeline_params.producer_arv_count = 1; + mma2accum_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + mma2accum_pipeline_params.initializing_warp = 6; + Mma2AccumPipeline mma2accum_pipeline(shared_storage.pipelines.mainloop.mma2accum_pipeline, + mma2accum_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask calculation + ); + + Mma2AccumPipelineState mma2accum_pipeline_consumer_state; + Mma2AccumPipelineState mma2accum_pipeline_producer_state = cutlass::make_producer_start_state(); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = 1; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // CLC pipeline + // Operates Scheduling Warp <--> All Warps + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumMainloopLoadBThreads + NumEpilogueThreads + + NumMMAThreads + NumTransformationThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + CLCPipelineState clc_pipeline_consumer_state; + CLCPipelineState clc_pipeline_producer_state = cutlass::make_producer_start_state(); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between transform, MMA, and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumTransformationThreads + NumMMAThreads + NumEpilogueThreads, + cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + NumMainloopLoadBThreads + + (is_epi_load_needed ? NumEpilogueLoadThreads : 0) + + NumTransformationThreads); + } + + + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + load2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster, cutlass::McastDirection::kRow); + load2mma_pipeline.init_masks(cluster_shape, cutlass::McastDirection::kCol); + transform2mma_pipeline.init_masks(cluster_shape); + mma2accum_pipeline.init_masks(cluster_shape); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // Allocate accumulators + auto acc_shape = collective_mainloop.partition_accumulator_shape(); + auto bulk_tmem = TiledMma::make_fragment_C(append(acc_shape, + Int{})); + + // Tile transform inputs now to get the k tile count + auto transform_inputs = collective_mainloop.transform_init(params.mainloop, problem_shape_MNKL, bulk_tmem, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(transform_inputs); + + // Synchronization call. Blocks wait until barriers are initialized in shared memory. + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + do { + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(Load2TransformPipeline::Stages, k_tile_count); + + if(is_participant.main_loadA){ + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + } + + if (lane_predicate) { + if(is_participant.main_loadA){ + auto [load2transform_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load_A( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [load2transform_pipeline_producer_state_next_, unused_] = collective_mainloop.load_A( + params.mainloop, + load2transform_pipeline, + load2transform_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2transform_pipeline_producer_state = load2transform_pipeline_producer_state_next_; + } + + if(is_participant.main_loadB){ + auto [load2mma_pipeline_producer_state_next, k_tile_iter_next] = collective_mainloop.load_B( + params.mainloop, + load2mma_pipeline, + load2mma_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + load2mma_pipeline_producer_state = load2mma_pipeline_producer_state_next; + + auto [load2mma_pipeline_producer_state_next_, unused_] = collective_mainloop.load_B( + params.mainloop, + load2mma_pipeline, + load2mma_pipeline_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + load2mma_pipeline_producer_state = load2mma_pipeline_producer_state_next_; + + } + } + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + + if(is_participant.main_loadA){ + if (lane_predicate) { + load2transform_pipeline.producer_tail(load2transform_pipeline_producer_state); + } + } + if(is_participant.main_loadB){ + if (lane_predicate) { + load2mma_pipeline.producer_tail(load2mma_pipeline_producer_state); + } + } + + } + + else if (is_participant.sched) { + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipeline_producer_state = scheduler.advance_to_next_work( + clc_pipeline, + clc_pipeline_producer_state + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipeline_producer_state); + } + } + + else if (is_participant.transformation) { + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + auto [load2transform_pipeline_consumer_state_next, transform2mma_pipeline_producer_state_next] = collective_mainloop.transform( + load2transform_pipeline, + load2transform_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_producer_state, + bulk_tmem, + transform_inputs, + k_tile_iter, k_tile_count + ); + transform2mma_pipeline_producer_state = transform2mma_pipeline_producer_state_next; + load2transform_pipeline_consumer_state = load2transform_pipeline_consumer_state_next; + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + } while (work_tile_info.is_valid()); + + transform2mma_pipeline.producer_tail(transform2mma_pipeline_producer_state); + } + + else if (is_participant.mma) { + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + + auto mma_input_operands = collective_mainloop.mma_init(bulk_tmem, shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (is_mma_leader_cta) { + auto [load2mma_pipeline_consumer_state_next, transform2mma_pipeline_consumer_state_next, mma2accum_pipeline_producer_state_next] = collective_mainloop.mma( + load2mma_pipeline, + load2mma_pipeline_consumer_state, + transform2mma_pipeline, + transform2mma_pipeline_consumer_state, + mma2accum_pipeline, + mma2accum_pipeline_producer_state, + bulk_tmem, + mma_input_operands, + k_tile_count + ); + // Advance the mm2accum pipe + load2mma_pipeline_consumer_state = load2mma_pipeline_consumer_state_next; + transform2mma_pipeline_consumer_state = transform2mma_pipeline_consumer_state_next; + mma2accum_pipeline_producer_state = mma2accum_pipeline_producer_state_next; + } + } while (work_tile_info.is_valid()); + + // leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + mma2accum_pipeline.producer_tail(mma2accum_pipeline_producer_state); + } + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue + ); + + do_tail_load = true; + } + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocation + tmem_allocation_result_barrier.arrive_and_wait_unaligned(); + + auto accum_inputs = collective_mainloop.accum_init(bulk_tmem, typename CollectiveEpilogue::CopyOpT2R{}, typename CollectiveEpilogue::EpilogueTile{}); + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipeline_consumer_state + ); + + if (increment_pipe) { + ++clc_pipeline_consumer_state; + } + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + mma2accum_pipeline.consumer_wait(mma2accum_pipeline_consumer_state); + + // Accumulators + Tensor accumulators = bulk_tmem(_,_,_,mma2accum_pipeline_consumer_state.index()); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + mma2accum_pipeline_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulators, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, mma2accum_pipeline_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + mma2accum_pipeline, + mma2accum_pipeline_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulators, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + do_tail_store = true; + + // Advance the mma2accum pipe + mma2accum_pipeline_consumer_state = mma2accum_pipeline_state_next; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + else { + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index d12241aa..07d00fb6 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -88,9 +88,7 @@ public: static_assert(cute::is_static::value); auto selected_cluster_shape = cutlass::detail::select_cluster_shape(cluster_shape_mnk, hw_info.cluster_shape); - auto cta_shape = cute::conditional_return>( - shape_div(tile_shape_mnk, atom_thr_shape_mnk), // Dynamic Cluster: For 2SM kernels, use CTA tile shape for the underlying scheduler - shape_div(tile_shape_mnk, selected_cluster_shape)); // Static Cluster: Blackwell builders expects TileShape to be Cluster's Tile Shape, Hopper doesn't + auto cta_shape = shape_div(tile_shape_mnk, atom_thr_shape_mnk); // For 2SM kernels, use CTA tile shape for the underlying scheduler dim3 problem_blocks = get_tiled_cta_shape_mnl( problem_shapes, diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp new file mode 100644 index 00000000..5d6f13ec --- /dev/null +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -0,0 +1,1319 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(typename ProblemShape::UnderlyingProblemShape{}) == 3 or rank(typename ProblemShape::UnderlyingProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using InternalStrideA = typename CollectiveMainloop::InternalStrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using InternalStrideB = typename CollectiveMainloop::InternalStrideB; + using LayoutSFA = typename CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename CollectiveMainloop::LayoutSFB; + using ElementSF = typename CollectiveMainloop::ElementSF; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using InternalStrideC = typename CollectiveEpilogue::InternalStrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using InternalStrideD = typename CollectiveEpilogue::InternalStrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = cute::conditional_t::Scheduler, + typename detail::TileSchedulerSelector< + TileSchedulerTag_, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler>; + + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = 3 * NumThreadsPerWarp; // 3 warp + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopABPipeline = typename CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; + + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename CollectiveMainloop::MainloopSFPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cute::conditional_t, + cutlass::PipelineAsync>; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cute::conditional_t, + cutlass::PipelineEmpty>; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(8) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorMapStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorMapStorage = typename CollectiveEpilogue::TensorMapStorage; + using MainloopTensorMapStorage = typename CollectiveMainloop::TensorMapStorage; + alignas(128) EpilogueTensorMapStorage epilogue; + alignas(128) MainloopTensorMapStorage mainloop; + } tensormaps; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + MainloopSFLoad = 3, + Epilogue = 4, // Warps [4-8) + EpilogueLoad = 8, + Unused = 9 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_ab_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + constexpr uint32_t NumEpilogueSubTiles = 1; + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + ProblemShape problem_shapes = args.problem_shape; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (IsGroupedGemmKernel && sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + else if (!IsGroupedGemmKernel && sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(problem_shapes, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + void* mainloop_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveMainloop::get_workspace_size(problem_shapes, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, problem_shapes.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + + TileSchedulerParams scheduler; + if constexpr (IsGroupedGemmKernel) { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace); + } + else { + scheduler = TileScheduler::to_underlying_arguments( + problem_shapes.get_host_problem_shape(), TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ); + } + + return { + args.mode, + problem_shapes, + CollectiveMainloop::to_underlying_arguments(problem_shapes, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(problem_shapes, args.epilogue, epilogue_workspace), + scheduler, + args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = true; + if constexpr (IsGroupedGemmKernel) { + // Group GEMM currently only supports rank-3 problem shapes + implementable &= (args.mode == GemmUniversalMode::kGrouped && rank(typename ProblemShape::UnderlyingProblemShape{}) == 3); + } else { + implementable &= (args.mode == GemmUniversalMode::kArray && rank(typename ProblemShape::UnderlyingProblemShape{}) == 4); + } + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Dynamic Cluster or Preferred Cluster don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); + return implementable; + } + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + constexpr uint32_t NumEpilogueSubTiles = 1; + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + // Mainloop + workspace_size += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinTensorMapWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + constexpr uint32_t NumEpilogueSubTiles = 1; + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Mainloop + status = CollectiveMainloop::initialize_workspace(args.problem_shape, args.mainloop, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveMainloop::get_workspace_size(args.problem_shape, args.mainloop, args.hw_info.sm_count); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape.get_host_problem_shape(0), args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinTensorMapWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE: cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + dim3 grid_shape; + if constexpr (IsGroupedGemmKernel) { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + else { + grid_shape = TileScheduler::get_grid_shape( + params.scheduler, + params.problem_shape.get_host_problem_shape(), + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + return grid_shape; + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + +private: + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_dealloc() { + cutlass::arch::warpgroup_reg_dealloc(); + } + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_alloc() { + cutlass::arch::warpgroup_reg_alloc(); + } + +public: + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + auto problem_shape = params.problem_shape; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = (warp_idx >= static_cast(WarpCategory::Epilogue) && warp_idx < static_cast(WarpCategory::EpilogueLoad)) ? WarpCategory::Epilogue : + WarpCategory(warp_idx); + if (warp_idx > static_cast(WarpCategory::EpilogueLoad)) { + warp_category = WarpCategory::Unused; + } + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = IsSchedDynamicPersistent ? (cta_rank_in_cluster == 0) : true; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // empty + }; + + // Mainloop Load pipeline + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; + // Initialize the barrier for TMA load prefetch + } + if (WarpCategory::MMA == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; + } + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::ABTmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop SF load pipeline + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if (WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; + } + mainloop_sf_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_sf_load; + mainloop_sf_pipeline_params.transaction_bytes = CollectiveMainloop::SFTransactionBytes; + mainloop_sf_pipeline_params.initializing_warp = 0; + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad || warp_category == WarpCategory::MainloopSFLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads + NumMainloopSFLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = IsSchedDynamicPersistent ? + CLCPipeline::ThreadCategory::ProducerConsumer : + CLCPipeline::ThreadCategory::Producer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + + clc_pipeline_params.initializing_warp = 1; + clc_pipeline_params.producer_arv_count = 1; + + if constexpr (IsSchedDynamicPersistent) { + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + } + else { + clc_pipeline_params.consumer_arv_count = NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads; + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += NumEpilogueLoadThreads; + } + } + // Now declare the pipeline outside the if constexpr + CLCPipeline clc_pipeline = [&]() { + if constexpr (IsSchedDynamicPersistent) { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + } + else { + return CLCPipeline(shared_storage.pipelines.clc, clc_pipeline_params); + } + }(); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if constexpr (IsSchedDynamicPersistent) { + if (WarpCategory::MainloopABLoad == warp_category || WarpCategory::MainloopSFLoad== warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopSFLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + } + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); + + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + int32_t sm_id = static_cast(cutlass::arch::SmId()); + + // Calculate mask after cluster barrier arrival + mainloop_ab_pipeline.init_masks(cluster_shape); + mainloop_sf_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + ThrMMA cta_mma = tiled_mma.get_slice(cta_coord_v); + auto acc_shape = partition_shape_C(tiled_mma, take<0,2>(TileShape{})); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + + pipeline_init_wait(cluster_size); + + if constexpr (IsGroupedGemmKernel) { + if (not work_tile_info.is_valid()) { + // When problem shapes are only on device, the grid launched may be larger than the total number of blocks across groups + return; + } + // In case user wants to engage less SMs than available on device + sm_id = blockIdx.x + (blockIdx.y * gridDim.x); + } + + auto problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + + if (is_participant.main_ab_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id); + Tensor gA_mkl = get<0>(load_inputs); + // Fetch a copy of tensormaps for the CTA from Params + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize + bool did_batch_change = true; + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update_ab( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); + // Problem Shape and therefore strides that we construct are [M,N,K,L], but since here for the TMA loads + // we are managing TMA descriptors to change batches, we need to neglect the L mode + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false, /* did_batch_change - prologue loads handle tensormap acquire */ + enable_prefetch ? k_tile_count - k_tile_prologue : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_ab_pipeline, mainloop_ab_pipe_producer_state); + + } + + else if (is_participant.sched) { + set_warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + else { + do { + auto [next_work_tile_info, increment_pipe] = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + work_tile_info = next_work_tile_info; + if (increment_pipe) { + ++clc_pipe_producer_state; + } + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + else if (is_participant.main_sf_load) { + set_warpgroup_reg_dealloc(); + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, + shared_storage.tensormaps.mainloop, + params.hw_info.sm_count, sm_id, work_tile_info.L_idx); + + auto gA_mkl = collective_mainloop.get_mkl_shape_tensor(problem_shape_MNKL); + auto input_tensormaps = get(load_inputs); + + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + do { + int32_t curr_batch = idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); // Usually just returns work_tile_info.L_idx; + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + if (did_batch_change) { + collective_mainloop.tensormaps_perform_update_sf( + shared_storage.tensormaps.mainloop, + params.mainloop, + input_tensormaps, + problem_shape, + curr_batch + ); + } + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopSFPipeline::Stages/2, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); // maybe we could use ceil_div(gSFA_mkl, 2); + auto cta_coord_mnk = append<4>(make_coord(get<0>(cta_coord_mnkl), get<1>(cta_coord_mnkl), get<2>(cta_coord_mnkl)), Int<0>{}); + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnk, + k_tile_iter, k_tile_prologue, + did_batch_change, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue, + false, /* did_batch_change - prologue loads handle tensormap acquire */ + enable_prefetch ? k_tile_count - k_tile_prologue : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != idx2crd(work_tile_info.L_idx, shape<4>(gA_mkl)); + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_sf_pipeline, mainloop_sf_pipe_producer_state); + + } + + else if (is_participant.mma) { + set_warpgroup_reg_dealloc(); + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + do { + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(work_tile_info.L_idx), 1); + } + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + int stage_idx = (IsOverlappingAccum) ? (accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); + Tensor accumulator = accumulators(_,_,_, stage_idx); + + if (is_mma_leader_cta) { + auto [mainloop_ab_pipe_consumer_state_next, mainloop_sf_pipe_consumer_state_next] = collective_mainloop.mma( + cute::make_tuple(mainloop_ab_pipeline, mainloop_sf_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_ab_pipe_consumer_state, mainloop_sf_pipe_consumer_state, accumulator_pipe_producer_state), + accumulator, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_next; + mainloop_sf_pipe_consumer_state = mainloop_sf_pipe_consumer_state_next; + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + + + ++accumulator_pipe_producer_state; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Fetch a copy of tensormaps for the CTA from Params + auto epi_load_tensormap = get<0>(collective_epilogue.load_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + + bool did_batch_change = true; + constexpr bool IsEpiLoad = true; + + do { + int32_t curr_batch = work_tile_info.L_idx; + if (did_batch_change) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_load_tensormap, + problem_shape, + curr_batch + ); + } + + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + cute::make_tuple(epi_load_tensormap, did_batch_change), + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + set_warpgroup_reg_alloc(); + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + auto warp_idx_in_epi = canonical_warp_idx_sync() - static_cast(WarpCategory::Epilogue); + bool do_tail_store = false; + // Fetch a copy of tensormaps for the CTA from Params + auto epi_store_tensormap = get<0>(collective_epilogue.store_init( + params.epilogue, shared_storage.tensormaps.epilogue, params.hw_info.sm_count, sm_id)); + // Initial batch's tensor address update + // Even the first tile for a CTA can be from any of the batches. + // And during initialization of the first TMA descriptor on host, we don't initialize to the first batch due to that args value being device-only. + bool did_batch_change = true; + constexpr bool IsEpiLoad = false; + do { + int32_t curr_batch = work_tile_info.L_idx; + + + if (did_batch_change && warp_idx_in_epi == 0) { + collective_epilogue.template tensormaps_perform_update( + shared_storage.tensormaps.epilogue, + params.epilogue, + epi_store_tensormap, + problem_shape, + curr_batch + ); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Accumulator stage slice after making sure allocation has been performed + int acc_stage = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + // Fusions may need problem shape for the current group + if constexpr (IsGroupedGemmKernel) { + problem_shape_MNKL = append<4>(problem_shape.get_problem_shape(curr_batch), 1); + } + + // Epilogue and write to gD + // + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + collective_mainloop.slice_accumulator(accumulators, acc_stage), + shared_storage.tensors.epilogue, + cute::make_tuple(epi_store_tensormap, did_batch_change) + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + + do_tail_store |= TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // For subsequent tiles, check if batch changes and therefore, we need tensormap updates + did_batch_change = curr_batch != work_tile_info.L_idx; + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + + } + + else { + set_warpgroup_reg_dealloc(); + } + + } + + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp new file mode 100644 index 00000000..ae93b2ff --- /dev/null +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp @@ -0,0 +1,1112 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/arch/grid_dependency_control.h" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm100_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" +#include "cute/atom/mma_atom.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using LayoutSFA = typename CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename CollectiveMainloop::LayoutSFB; + using ElementSF = typename CollectiveMainloop::ElementSF; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsNoSmemEpilogue = is_same_v; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopABLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopSFLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + static constexpr uint32_t NumEpilogueLoadThreads = IsNoSmemEpilogue ? 0 : NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEmptyThreads = IsNoSmemEpilogue ? 0 : 3 * NumThreadsPerWarp; // 3 warp + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads + NumEmptyThreads; + + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopABPipeline = typename CollectiveMainloop::MainloopABPipeline; + using MainloopABPipelineState = typename CollectiveMainloop::MainloopABPipelineState; + + using MainloopSFPipeline = typename CollectiveMainloop::MainloopSFPipeline; + using MainloopSFPipelineState = typename CollectiveMainloop::MainloopSFPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr int EpilogueWarpRegs = 248; + static constexpr int NonEpilogueWarpRegs = 128; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(8) arch::ClusterBarrier tmem_dealloc; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopABLoad = 2, + MainloopSFLoad = 3, + Epilogue = 4, // Warps [4-8) + EpilogueLoad = 8, + Unused = 9 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_ab_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + uint32_t main_sf_load = false; + uint32_t unused = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + constexpr int NumEpilogueSubTiles = 1; + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster shape check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + constexpr int NumEpilogueSubTiles = 1; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr int NumEpilogueSubTiles = 1; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + +private: + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_dealloc() { + if constexpr (not IsNoSmemEpilogue) { + cutlass::arch::warpgroup_reg_dealloc(); + } + } + + static constexpr + CUTLASS_DEVICE + void set_warpgroup_reg_alloc() { + if constexpr (not IsNoSmemEpilogue) { + cutlass::arch::warpgroup_reg_alloc(); + } + } + +public: + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = (warp_idx >= static_cast(WarpCategory::Epilogue) && warp_idx < static_cast(WarpCategory::EpilogueLoad)) ? WarpCategory::Epilogue : + WarpCategory(warp_idx); + if (warp_idx > static_cast(WarpCategory::EpilogueLoad)) { + warp_category = WarpCategory::Unused; + } + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, cute::cluster_shape()); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopABLoad), // main_ab_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue), // epilogue + (warp_category == WarpCategory::MainloopSFLoad), // main_sf_load + (warp_category == WarpCategory::Unused) // empty + }; + + // Mainloop Load pipeline + typename MainloopABPipeline::Params mainloop_ab_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Producer; + // Initialize the barrier for TMA load prefetch + } + if (WarpCategory::MMA == warp_category) { + mainloop_ab_pipeline_params.role = MainloopABPipeline::ThreadCategory::Consumer; + } + mainloop_ab_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_ab_load; + mainloop_ab_pipeline_params.transaction_bytes = CollectiveMainloop::ABTmaTransactionBytes; + mainloop_ab_pipeline_params.initializing_warp = 0; + MainloopABPipeline mainloop_ab_pipeline(shared_storage.pipelines.mainloop.pipeline_ab, + mainloop_ab_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Mainloop SF load pipeline + typename MainloopSFPipeline::Params mainloop_sf_pipeline_params; + if (WarpCategory::MainloopSFLoad == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_sf_pipeline_params.role = MainloopSFPipeline::ThreadCategory::Consumer; + } + mainloop_sf_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_sf_load; + mainloop_sf_pipeline_params.transaction_bytes = CollectiveMainloop::SFTransactionBytes; + mainloop_sf_pipeline_params.initializing_warp = 0; + MainloopSFPipeline mainloop_sf_pipeline(shared_storage.pipelines.mainloop.pipeline_sf, + mainloop_sf_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopABLoad || warp_category == WarpCategory::MainloopSFLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopABLoadThreads + NumMainloopSFLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopABLoadThreads + NumMainloopSFLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (WarpCategory::MMA == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + accumulator_pipeline_params.initializing_warp = 2; + AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopABLoad == warp_category || WarpCategory::MainloopSFLoad== warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + + clc_throttle_pipeline_params.producer_arv_count = NumMainloopSFLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if constexpr(!IsOverlappingAccum) { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (WarpCategory::MMA == warp_category && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + MainloopABPipelineState mainloop_ab_pipe_consumer_state; + MainloopABPipelineState mainloop_ab_pipe_producer_state = cutlass::make_producer_start_state(); + + MainloopSFPipelineState mainloop_sf_pipe_consumer_state; + MainloopSFPipelineState mainloop_sf_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_ab_pipeline.init_masks(cluster_shape); + mainloop_sf_pipeline.init_masks(cluster_shape); + accumulator_pipeline.init_masks(cluster_shape); + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + // + // TMEM "Allocation" + // + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + TiledMma tiled_mma; + ThrMMA cta_mma = tiled_mma.get_slice(cta_coord_v); + auto acc_shape = partition_shape_C(tiled_mma, take<0,2>(TileShape{})); + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + +#if 1 + pipeline_init_wait(cluster_size); + + if (is_participant.main_ab_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(load_inputs); + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + + do { + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopABPipeline::Stages, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next; + + if constexpr (not IsNoSmemEpilogue) { + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_ab( + params.mainloop, + mainloop_ab_pipeline, + mainloop_ab_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue, + enable_prefetch ? k_tile_count - k_tile_prologue : 0 + ); + mainloop_ab_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_ab_pipeline, mainloop_ab_pipe_producer_state); + + } + + else if (is_participant.sched) { + set_warpgroup_reg_dealloc(); + + if constexpr (IsSchedDynamicPersistent) { + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + cutlass::arch::wait_on_dependent_grids(); + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + } + } + + + else if (is_participant.main_sf_load) { + set_warpgroup_reg_dealloc(); + bool do_load_order_arrive = is_epi_load_needed; + auto load_inputs = collective_mainloop.load_sf_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + auto tmp = collective_mainloop.load_ab_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + Tensor gA_mkl = get<0>(tmp); // just to get k_tile_count or maybe we could use ceil_div(shape<3>(gSFA_mkl), 2); + bool requires_clc_query = true; + // 2cta: 4x4/4x2/2x4 enable the PF + bool enable_prefetch = shape<0>(AtomThrShapeMNK{}) == 2 and + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 4) or + (size<0>(cluster_shape) == 4 and size<1>(cluster_shape) == 2) or + (size<0>(cluster_shape) == 2 and size<1>(cluster_shape) == 4); + do { + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_prologue = min(MainloopSFPipeline::Stages/2, k_tile_count); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); // maybe we could use ceil_div(gSFA_mkl, 2); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, k_tile_iter_next] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue, + enable_prefetch ? k_tile_count : 0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next; + + if constexpr (not IsNoSmemEpilogue) { + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + } + + auto [mainloop_producer_state_next_, unused_] = collective_mainloop.load_sf( + params.mainloop, + mainloop_sf_pipeline, + mainloop_sf_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue, + enable_prefetch ? k_tile_count - k_tile_prologue :0 + ); + mainloop_sf_pipe_producer_state = mainloop_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + collective_mainloop.load_tail(mainloop_sf_pipeline, mainloop_sf_pipe_producer_state); + + } + + + else if (is_participant.mma) { + set_warpgroup_reg_dealloc(); + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + int tmem_non_accumulator_base = tmem_base_ptr + cutlass::detail::find_tmem_tensor_col_offset(accumulators); + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + shared_storage.tensors.mainloop, + tmem_non_accumulator_base /*Start SF TMEM allocation after the accumulator*/); + + do { + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + // Wait for tmem accumulator buffer to become empty with a flipped phase + if constexpr (!IsOverlappingAccum) { + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } + } + int stage_idx = (IsOverlappingAccum) ? (accumulator_pipe_producer_state.phase() ^ 1) : (accumulator_pipe_producer_state.index()); + Tensor accumulator = accumulators(_,_,_, stage_idx); + + if (is_mma_leader_cta) { + auto [mainloop_ab_pipe_consumer_state_next, mainloop_sf_pipe_consumer_state_next] = collective_mainloop.mma( + cute::make_tuple(mainloop_ab_pipeline, mainloop_sf_pipeline, accumulator_pipeline), + cute::make_tuple(mainloop_ab_pipe_consumer_state, mainloop_sf_pipe_consumer_state, accumulator_pipe_producer_state), + accumulator, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + + mainloop_ab_pipe_consumer_state = mainloop_ab_pipe_consumer_state_next; + mainloop_sf_pipe_consumer_state = mainloop_sf_pipe_consumer_state_next; + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } + ++accumulator_pipe_producer_state; + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. + cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + if constexpr (!IsOverlappingAccum) { + // Leader MMA waits for leader + peer epilogues to release accumulator stage + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + } + else { + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (not IsNoSmemEpilogue and is_participant.epi_load) { + set_warpgroup_reg_dealloc(); + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + set_warpgroup_reg_alloc(); + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + accumulators.data() = tmem_base_ptr; + + + bool do_tail_store = false; + do { + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + int stage_idx = [&] () { + if constexpr (IsOverlappingAccum) { + return accumulator_pipe_consumer_state.phase(); + } + else { + return accumulator_pipe_consumer_state.index(); + } + }(); + + // Accumulator + Tensor accumulator = accumulators(_,_,_,stage_idx); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + accumulator_pipe_consumer_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accumulator, + accumulator_pipeline, + accumulator_pipe_consumer_state, + typename CollectiveEpilogue::CopyOpT2R{} + ); + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next, acc_state_next] = collective_epilogue.template store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + accumulator_pipeline, + accumulator_pipe_consumer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accumulator, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + accumulator_pipe_consumer_state = acc_state_next; + do_tail_store = true; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + if constexpr (IsOverlappingAccum) { + // Signal to peer MMA that Full TMEM alloc can be deallocated + if constexpr (has_mma_peer_cta) { + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank); + } + tmem_deallocation_result_barrier.arrive(); + } + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). + if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } + + else { + set_warpgroup_reg_dealloc(); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 4f5723da..92bd5536 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -411,9 +411,11 @@ public: using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index f33f4685..6ac24d34 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -423,9 +423,11 @@ public: using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 5bdaba15..5b558005 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -270,9 +270,11 @@ public: using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 0b12aacd..d398d1f2 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -342,9 +342,11 @@ public: using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index fc4f5fc1..1326f390 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -354,9 +354,11 @@ public: using namespace cute; using X = Underscore; -#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200)) -# define ENABLE_SM90_KERNEL_LEVEL 1 -#endif +# if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL) || defined(__CUDA_ARCH_FEAT_SM121_ALL) ||\ + CUDA_ARCH_CONDITIONAL_OR_FAMILY(1200) || CUDA_ARCH_CONDITIONAL_OR_FAMILY(1210)) +# define ENABLE_SM90_KERNEL_LEVEL 1 +# endif + // Any Tensor Op MMA Atom in the ISA is arch conditional. #if ! defined(ENABLE_SM90_KERNEL_LEVEL) printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp index aa7bd0dc..d78bc4b0 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -298,6 +298,66 @@ struct TileSchedulerSelector< using Scheduler = StaticPersistentTileScheduler100; }; +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// Ptr-Array kernel may provide a specialized ArrayProblemShape type +template +struct TileSchedulerSelector< + PersistentScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + ProblemShape> { + using Scheduler = PersistentTileSchedulerSm100< + ClusterShape, + SchedulerPipelineStageCount>; +}; + +// SM103 Group tile scheduler +template < + class TileShape, + class ClusterShape, + uint32_t SchedulerPipelineStageCount, + class GroupProblemShape +> +struct TileSchedulerSelector< + GroupScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount, + GroupProblemShape + > { + using Scheduler = PersistentTileSchedulerSm100Group; +}; + +template +struct TileSchedulerSelector< + StreamKScheduler, + arch::Sm103, + TileShape, + ClusterShape, + SchedulerPipelineStageCount> { + using Scheduler = PersistentTileSchedulerSm100StreamK< + TileShape, + ClusterShape, + SchedulerPipelineStageCount>; +}; + // Default (void) for Sm120 maps to PersistentTileSchedulerSm100 template struct TileSchedulerSelector< diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 030e5845..43047eae 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -35,14 +35,13 @@ */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include "cutlass/cutlass.h" #include "cutlass/numeric_size.h" #include "cutlass/platform/platform.h" @@ -206,6 +205,12 @@ using int2b_t = integer_subbyte<2, true>; /// 2-bit Unsigned integer type using uint2b_t = integer_subbyte<2, false>; +/// 3-bit Integer type +using int3b_t = integer_subbyte<3, true>; + +/// 3-bit Unsigned integer type +using uint3b_t = integer_subbyte<3, false>; + /// 4-bit Integer type using int4b_t = integer_subbyte<4, true>; diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 32a6ee0d..99e3353f 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -38,9 +38,8 @@ computation lies in operator() with private member variables {col_permute_, row_permute_ and stride_} as new addresses after permute op. */ #pragma once - -#include #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index faf64275..9e8a354e 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -39,9 +39,8 @@ defined in cutlass/tensor_ref.h. */ #pragma once - -#include #include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 1d5856f9..7aad6c24 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -3901,16 +3901,16 @@ struct NumericArrayConverter { static result_type convert(source_type const & source) { #if defined(CUDA_PTX_FP4FP6_CVT_ENABLED) - unsigned out; + uint16_t out; asm volatile( \ "{\n" \ ".reg .b8 byte0;\n" \ ".reg .b8 byte1;\n" \ "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ - "mov.b32 %0, {byte0, byte1, 0, 0};\n" \ + "mov.b16 %0, {byte0, byte1};\n" \ "}" \ - : "=r"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); + : "=h"(out) : "f"(source[0]), "f"(source[1]), "f"(source[2]), "f"(source[3])); return reinterpret_cast(out); #else diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index b79a3d25..0d814ed2 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -43,6 +43,7 @@ #include "cutlass/tfloat32.h" #include "cutlass/float8.h" #include "cutlass/uint128.h" +#include "cutlass/uint256.h" #include "cutlass/exmy_base.h" #include "cutlass/float_subbyte.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp index 53bc9199..4014bd00 100644 --- a/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -140,6 +140,8 @@ public: int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); } @@ -320,6 +322,24 @@ public: } } + template + CUTLASS_DEVICE + PipelineTmaTransformAsync(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + // Helper function to initialize barriers template static @@ -334,9 +354,11 @@ public: static constexpr bool IsDynamicCluster = not cute::is_static_v; static_assert(IsDynamicCluster or ((cute::size<0>(cluster_shape) % cute::size<0>(atom_thr_shape) == 0) && (cute::size<1>(cluster_shape) % cute::size<1>(atom_thr_shape) == 0))); - uint32_t const num_consumer_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); uint32_t const multicast_consumer_arrival_count = ((cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1) * num_consumer_per_cluster; + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -344,8 +366,31 @@ public: } template + static CUTLASS_DEVICE - void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const num_consumer_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) * num_consumer_per_cluster : // Mcast with row ctas + (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) * num_consumer_per_cluster; // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + + } + cutlass::arch::fence_barrier_init(); + } + + template + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster(), McastDirection mcast_dir = McastDirection::kRowCol) { // Calculate consumer mask if (params_.role == ThreadCategory::Consumer) { // Logic to optimally schedule Empty Arrives @@ -374,10 +419,25 @@ public: // STEP 2: Find if this dst block-id needs an arrival for this problem is_signaling_thread_ &= dst_blockid_ < cluster_size; - is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); + if(mcast_dir == McastDirection::kRowCol){ + is_signaling_thread_ &= is_same_row_or_col(dst_blockid_, block_id_in_cluster, cluster_shape); + } + if(mcast_dir == McastDirection::kRow){ + is_signaling_thread_ &= is_same_row(dst_blockid_, block_id_in_cluster, cluster_shape); + } } } + template + CUTLASS_DEVICE + bool is_same_row(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { + return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) + // If we are in the same cluster column and using 2CTA MMA, only odd or only even CTAs sync with each other + && ((dst_block_id % cute::size<0>(cluster_shape)) % cute::size<0>(AtomThrShape_MNK{}) == + block_id.x % cute::size<0>(AtomThrShape_MNK{})) + ); + } + template CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { @@ -504,7 +564,8 @@ public: auto atom_thr_shape = AtomThrShape_MNK{}; uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; - + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -525,6 +586,8 @@ public: cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -893,6 +956,8 @@ public: int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); } @@ -910,6 +975,8 @@ public: int warp_idx = canonical_warp_idx_sync(); if (warp_idx == params.initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr_, empty_barrier_ptr_, params_.producer_arv_count, params_.consumer_arv_count); } @@ -1102,7 +1169,7 @@ public: /////////////////////////////////////////////////////////////////////////////////////////////////// // // TMA (producer - consumer) Async Pipeline classes for Blackwell Sparse UMMA -// This is designed for the parttern that kernel has two different staged tensors. (AB and metadata) +// This is designed for the pattern that kernel has two different staged tensors. (AB and metadata) // /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 0828c1ea..aae17d98 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -309,13 +309,14 @@ public: if (is_initializing_warp) { // Barrier FULL and EMPTY init uint32_t const producer_arv_cnt = params.num_producers; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = cute::ceil_div(params.num_consumers, static_cast(NumThreadsPerWarpGroup)); uint32_t multicast_consumer_arrival_count = params.num_consumers; // If cluster_size is 1 if (cute::size(cluster_shape) > 1) { multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * num_consumer_warpgroups_per_cluster; } - + CUTLASS_ASSERT(multicast_consumer_arrival_count > 0 && "Multicast consumer arrival count must be non-zero"); + CUTLASS_ASSERT(producer_arv_cnt > 0 && "Producer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); } @@ -804,6 +805,8 @@ public: if (is_initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( full_barrier_ptr, empty_barrier_ptr, params.producer_arv_count, params.consumer_arv_count); } @@ -1043,6 +1046,8 @@ public: is_initializing_warp = (warp_idx == params.initializing_warp); if (is_initializing_warp) { // Barrier FULL and EMPTY init + CUTLASS_ASSERT(params.producer_arv_count > 0 && "Producer arrival count must be non-zero"); + CUTLASS_ASSERT(params.consumer_arv_count > 0 && "Consumer arrival count must be non-zero"); cutlass::arch::detail::initialize_barrier_array_pair_aligned( storage.full_barrier_, storage.empty_barrier_, params.producer_arv_count, params.consumer_arv_count); } @@ -1299,6 +1304,7 @@ public: // Barrier FULL, EMPTY init if (warp_idx == params.initializing_warp) { int arv_cnt = params.group_size; + CUTLASS_ASSERT(arv_cnt > 0 && "Arrive count must be non-zero"); constexpr int Stages = Depth * Length; cutlass::arch::detail::initialize_barrier_array_aligned( barrier_ptr_, arv_cnt); @@ -1307,6 +1313,7 @@ public: int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); + CUTLASS_ASSERT(params.group_size > 0 && "Group size must be non-zero"); // Barrier FULL, EMPTY init // Init is done only by the one elected thread of the block diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 939451a2..21d6ed16 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -98,13 +98,13 @@ //----------------------------------------------------------------------------- // Dependencies //----------------------------------------------------------------------------- - +#include #if defined(__CUDACC_RTC__) -#include -#include -#include -#include -#include +#include CUDA_STD_HEADER(type_traits) +#include CUDA_STD_HEADER(utility) +#include CUDA_STD_HEADER(cstddef) +#include CUDA_STD_HEADER(cstdint) +#include CUDA_STD_HEADER(limits) #else #include #include @@ -128,7 +128,6 @@ #endif #include -#include #endif @@ -523,7 +522,7 @@ using std::is_trivially_copyable; #endif -#if (201703L <=__cplusplus) +#if (CUTLASS_CXX17_OR_LATER) /// std::is_unsigned_v using CUTLASS_STL_NAMESPACE::is_integral_v; @@ -596,14 +595,6 @@ struct alignment_of { enum { value = 16 }; }; template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> -struct alignment_of { - enum { value = 16 }; -}; -template <> struct alignment_of { enum { value = 16 }; }; @@ -615,6 +606,15 @@ template <> struct alignment_of { enum { value = 16 }; }; + +template <> +struct alignment_of { + enum { value = 16 }; +}; +template <> +struct alignment_of { + enum { value = 16 }; +}; template <> struct alignment_of { enum { value = 16 }; @@ -628,6 +628,7 @@ struct alignment_of { enum { value = 16 }; }; + // Specializations for volatile/const qualified types template struct alignment_of : alignment_of {}; diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index 0241a6fd..c3867c57 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -33,16 +33,15 @@ of boolean predicates. */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #endif -#include +#include CUDA_STD_HEADER(cassert) -#include "cutlass/cutlass.h" #include "cutlass/platform/platform.h" namespace cutlass { diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index 295eaa68..68896d6b 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -33,9 +33,9 @@ \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. */ #pragma once - +#include "cutlass/cutlass.h" #if defined(__CUDACC_RTC__) -#include +#include CUDA_STD_HEADER(cstdint) #else #include #include @@ -44,7 +44,6 @@ #include #endif -#include "cutlass/cutlass.h" /// Optionally enable GCC's built-in type #if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) diff --git a/include/cutlass/uint256.h b/include/cutlass/uint256.h new file mode 100644 index 00000000..36578535 --- /dev/null +++ b/include/cutlass/uint256.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 256b integer. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif +#include "cutlass/uint128.h" + +namespace cutlass { + +///! Unsigned 256b integer type +struct alignas(32) uint256_t { + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 128; + + struct hilo { + uint128_t lo; + uint128_t hi; + }; + + // Use a union to store either low and high parts. + union { + struct hilo hilo_; + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint256_t() : hilo_{uint128_t{}, uint128_t{}} {} + + /// Constructor from uint128 + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_) : hilo_{lo_, uint128_t{}} {} + + /// Constructor from two 128b unsigned integers + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_, uint128_t hi_) : hilo_{lo_, hi_} {} + + /// Lossily cast to uint128_t + CUTLASS_HOST_DEVICE + explicit operator uint128_t() const { + return hilo_.lo; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/media/docs/cpp/cute/02_layout_algebra.md b/media/docs/cpp/cute/02_layout_algebra.md index 1b81e72f..8314495b 100644 --- a/media/docs/cpp/cute/02_layout_algebra.md +++ b/media/docs/cpp/cute/02_layout_algebra.md @@ -153,7 +153,6 @@ For example, To compute the strides of the strided layout, the residues of the above operation are used to scale the strides of `A`. For instance, the last example `(3,6,2,8):(w,x,y,z) / 72` with strides `(w,x,y,z)` produces `(3*w,6*x,2*x,2*z)` as the strides of the strided layout. - As you may have noticed, we can only divide shapes by certain values and get a sensible result. This is called the **stride divisibility condition** and is statically checked in CuTe when possible. 2. Keep the first `s` elements of the newly strided `A` so that the result has a compatible shape with `B`. This can be computed by "modding out" the first `s` elements from the shape of `A` starting from the left. @@ -175,11 +174,8 @@ Again, this operation must satisfy a **shape divisibility condition** to yield a From the above examples, we can construct the composition `(3,6,2,8):(w,x,y,z) o 16:9 = (1,2,2,4):(3*w,3*x,y,z)`. --- - #### Example 1 -- Worked Example of Calculating a Composition - We provide a more complex example of composition, where both operand layouts are multi-modal to illustrate the concepts introduced above. - ``` Functional composition, R := A o B R(c) := (A o B)(c) := A(B(c)) @@ -223,7 +219,6 @@ Putting this together and coalescing each mode, we obtain the result R = A o B = ((2, 2), 3): ((24, 2), 8) ``` - #### Example 2 -- Reshape a layout into a matrix `20:2 o (5,4):(4,1)`. Composition formulation. diff --git a/media/docs/cpp/cute/0z_tma_tensors.md b/media/docs/cpp/cute/0z_tma_tensors.md index 4b9c0070..d66fa1ac 100644 --- a/media/docs/cpp/cute/0z_tma_tensors.md +++ b/media/docs/cpp/cute/0z_tma_tensors.md @@ -138,13 +138,15 @@ In principle, layout strides may be any integer-module. CuTe's basis elements live in the header file `cute/numeric/arithmetic_tuple.hpp`. To make it easy to create `ArithmeticTuple`s that can be used as strides, CuTe defines normalized basis elements using the `E` type alias. "Normalized" means that the scaling factor of the basis element is the compile-time integer 1. -| C++ object | Description | String representation | -| --- | --- | --- | -| `E<>{}` | `1` | `1` | -| `E<0>{}` | `(1,0,...)` | `1@0` | -| `E<1>{}` | `(0,1,0,...)` | `1@1` | -| `E<0,1>{}` | `((0,1,0,...),0,...)` | `1@1@0` | -| `E<1,0>{}` | `(0,(1,0,...),0,...)` | `1@0@1` | +| C++ object | Description | String representation | +| --- | --- | --- | +| `E<>{}` | `1` | `1` | +| `E<0>{}` | `(1,0,...)` | `1@0` | +| `E<1>{}` | `(0,1,0,...)` | `1@1` | +| `E<0,0>{}` | `((1,0,...),0,...)` | `1@0@0` | +| `E<0,1>{}` | `((0,1,0,...),0,...)` | `1@1@0` | +| `E<1,0>{}` | `(0,(1,0,...),0,...)` | `1@0@1` | +| `E<1,1>{}` | `(0,(0,1,0,...),0,...)` | `1@1@1` | The "description" column in the above table interprets each basis element as an infinite tuple of integers, @@ -155,7 +157,9 @@ For example, `E<1>{}` has a 1 in position 1: `(0,1,0,...)`. Basis elements can be *nested*. For instance, in the above table, `E<0,1>{}` means that -in position 0 there is a `E<1>{}`: `((0,1,0,...),0,...)`. +in position 0 there is a `E<1>{}`: `((0,1,0,...),0,...)`. Similarly, +`1@1@0` means that `1` is lifted to position 1 to create `1@1`: `(0,1,0,...)` +which is then lifted again to position 0. Basis elements can be *scaled*. That is, they can be multiplied by an integer *scaling factor*. diff --git a/media/docs/cpp/getting_started.rst b/media/docs/cpp/getting_started.rst index df34f3f6..1dc5dc55 100644 --- a/media/docs/cpp/getting_started.rst +++ b/media/docs/cpp/getting_started.rst @@ -13,4 +13,5 @@ Getting Started Terminology Fundamental Types Programming Guidelines + GEMM Heuristics diff --git a/media/docs/cpp/heuristics.md b/media/docs/cpp/heuristics.md new file mode 100644 index 00000000..8b9166b1 --- /dev/null +++ b/media/docs/cpp/heuristics.md @@ -0,0 +1,102 @@ + +# GEMM Heuristics + +## Overview + +Gemm heuristics in `cutlass_library` aim to reduce the search space for runtime autotuning, so that only a subset of valid kernels need to be built and profiled for a given set of GEMM problems. This implementation uses Nvidia's `nvidia-matmul-heuristics`, an analytical heuristic that ranks GEMM kernels by estimated performance given a problem size and hardware SKU. You can find more info in [the docs](https://docs.nvidia.com/cuda/nvidia-matmul-heuristics). + +## Coverage + +Gemm heuristics in `cutlass_library` is an experimental feature and exhaustive functional or performance coverage is not guaranteed. It currently supports the following. + +Problem space: +- Plain dense gemm for `f8`, `f16`, `f32` + +Hardware: +- Hopper (sm9x) +- Blackwell (sm10x) + +## Usage / Quick Start + +### Install Dependencies + +Using the wheel is recommended: +``` +pip install nvidia-matmul-heuristics +``` + +### Prepare Input File + +Prepare a list of gemm problem definitions, in the form of a json list, to be evaluated by the heuristic. Here is a sample file with two problems: +``` +[ +{ + "m" : 4096, + "n" : 4096, + "k" : 4096, + "batch_count" : 1, + "layout" : "tnn", + "dtype_a" : "f16", + "dtype_b" : "f16", + "dtype_c" : "f16", + "dtype_acc" : "f32", + "dtype_d" : "f16", + "beta" : 0.0, + "use_fast_acc": false +}, +{ + "m" : 4096, + "n" : 4096, + "k" : 4096, + "batch_count" : 1, + "layout": "tnn", + "dtype_a" : "e5m2", + "dtype_b" : "e5m2", + "dtype_c" : "f32", + "dtype_acc" : "f32", + "dtype_d" : "e5m2", + "beta" : 0.0, + "use_fast_acc": true +} +] +``` + +Note: `use_fast_acc` only needs to be specified for FP8 kernels on SM90. Otherwise, it is ignored. + +### Build + +Build CUTLASS using CMake as normal, providing heuristics-specific options to CMake. Note that hardware details are detected automatically. For offline builds, use `-DCUTLASS_LIBRARY_HEURISTICS_GPU`. +For example, here is a minimal command for Nvidia's Hopper Architecture (sm90): + +``` +$ cmake .. \ + -DCUTLASS_NVCC_ARCHS=90a \ + -DCUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE= \ + -DCUTLASS_LIBRARY_HEURISTICS_CONFIGS_PER_PROBLEM= +... +... + +$ make cutlass_profiler -j + +``` + +This will produce a csv testlist which provides all testcases that need be run to perform autotuning over the built configurations, including kernel runtime options. The location of this file can be changed by the CMake option `-DCUTLASS_LIBRARY_HEURISTICS_TESTLIST_FILE`. + +CUTLASS CMake currently supports the following for heuristics: +- `CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE`: Path to the file containing a json list of GEMM problems +- `CUTLASS_LIBRARY_HEURISTICS_CONFIGS_PER_PROBLEM`: Max number of configurations the heuristic will return for each GEMM problem. The same configuration or kernel can be suggested for multiple problems. +- `CUTLASS_LIBRARY_HEURISTICS_RESTRICT_KERNELS`: Limits the build to only the set of kernels instantiated by the default CUTLASS CMake build flow, composing with other options such as `CUTLASS_LIBRARY_INSTANTIATION_LEVEL`. Set this to `ON` as a workaround if the heuristic suggests kernel configurations that do not build on your platform (possible for some unsupported or experimental use cases). This option is set to `OFF` by default, which builds all of the suggested configurations. +- `CUTLASS_LIBRARY_HEURISTICS_TESTLIST_FILE`: Path to the output CSV which will contain the testcases to be used for autotuning, consumable by `cutlass_profiler`. +- `CUTLASS_LIBRARY_HEURISTICS_GPU`: The GPU to use for heuristics; for instance, `H100_SXM5`. Used for offline builds. If unset, the hardware properties will be auto-detected using the Cuda Driver APIs. See `generator.py` for valid GPU strings + +### Profile + +Use the emitted testlist CSV with `cutlass_profiler` to collect performance data, which can be used to determine the fastest built kernel configuration for each of the input problems. Example which profiles each testcase for a fixed 50ms: +``` +cutlass_profiler --operation=Gemm --testlist-file= --profiling-iterations=0 --profiling-duration=50 --verification-enabled=false --output= +``` + +## Direct Usage in Python + +If you have pre-built CUTLASS kernels or custom CUTLASS emitters, you can use the Python APIs directly to select kernels to build or profile. See `filter_manifest_and_write_heuristics_file()` in `heuristics.py` for example usage. + diff --git a/media/docs/cpp/profiler.md b/media/docs/cpp/profiler.md index 22f88485..8331b75f 100644 --- a/media/docs/cpp/profiler.md +++ b/media/docs/cpp/profiler.md @@ -93,6 +93,12 @@ An instantiation level `500`, which is padded to `0500`, thus indicates: - **Cluster Sizes**: At level 5, allowing for clusters with 1, 2, 4, 8, or 16 CTAs. - **Schedule Pruning**: At level 0, where pruning is applied according to the existing `generator.py` behavior. +## Instantiating more MMA shapes with Hopper + +When instantiating more tile shapes, specially non-power-of-2 Tile-N shapes, make sure to enable `CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES`. +This may lead to some increase in per-kernel compilation times. +When `CUTLASS_LIBRARY_INSTANTIATION_LEVEL` is set, then `CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES` is enabled by default. + ## Mixed input data type kernels for Hopper With Hopper (SM90), the kernel generator will generate the following combinations of mixed input data types ("mixed dtype"): diff --git a/pyproject.toml b/pyproject.toml index 04571be9..c046dc94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nvidia-cutlass" -version = "3.9.0.0" +version = "4.0.0.0" description = "CUTLASS" readme = "README.md" requires-python = ">=3.8" diff --git a/python/README.md b/python/README.md index 3fb36f03..27003794 100644 --- a/python/README.md +++ b/python/README.md @@ -4,7 +4,7 @@ This directory contains Python packages that are associated with CUTLASS: -* `cutlass`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python +* `cutlass_cppgen`: the CUTLASS Python interface, which enables one to compile and run CUTLASS kernels from within Python. Note that this was previously named `cutlass`, but was renamed to disambiguate with the CuTe Python DSL. * `cutlass_library`: utilities used for enumerating and emitting C++ code for CUTLASS kernels ## CUTLASS Python Interface diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index 35cb4eba..dd0d7c62 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -119,8 +119,8 @@ def set_log_level(level: int): set_log_level(logging.ERROR) -from cutlass.library_defaults import OptionRegistry -from cutlass.backend.utils.device import device_cc +from cutlass_cppgen.library_defaults import OptionRegistry +from cutlass_cppgen.backend.utils.device import device_cc this._option_registry = None def get_option_registry(): @@ -135,14 +135,14 @@ def get_option_registry(): this.__version__ = '4.1.0' -from cutlass.backend import create_memory_pool -from cutlass.emit.pytorch import pytorch -from cutlass.op.gemm import Gemm -from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad -from cutlass.op.gemm_grouped import GroupedGemm -from cutlass.op.op import OperationBase -from cutlass.backend.evt.ir.tensor import Tensor -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.backend import create_memory_pool +from cutlass_cppgen.emit.pytorch import pytorch +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad +from cutlass_cppgen.op.gemm_grouped import GroupedGemm +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.backend.evt.ir.tensor import Tensor +from cutlass_cppgen.utils.lazy_import import lazy_import this.memory_pool = None diff --git a/python/cutlass/backend/__init__.py b/python/cutlass/backend/__init__.py index 1011cd22..59cfaf71 100644 --- a/python/cutlass/backend/__init__.py +++ b/python/cutlass/backend/__init__.py @@ -30,19 +30,19 @@ # ################################################################################################# -from cutlass.backend.arguments import * -from cutlass.backend.c_types import * -from cutlass.backend.compiler import ArtifactManager -from cutlass.backend.conv2d_operation import * -from cutlass.backend.epilogue import * -from cutlass.backend.frontend import * -from cutlass.backend.gemm_operation import * -from cutlass.backend.library import * -from cutlass.backend.memory_manager import PoolMemoryManager, create_memory_pool -from cutlass.backend.operation import * -from cutlass.backend.reduction_operation import * -from cutlass.backend.type_hint import * -from cutlass.backend.utils import * -from cutlass.backend.utils.device import device_cc +from cutlass_cppgen.backend.arguments import * +from cutlass_cppgen.backend.c_types import * +from cutlass_cppgen.backend.compiler import ArtifactManager +from cutlass_cppgen.backend.conv2d_operation import * +from cutlass_cppgen.backend.epilogue import * +from cutlass_cppgen.backend.frontend import * +from cutlass_cppgen.backend.gemm_operation import * +from cutlass_cppgen.backend.library import * +from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool +from cutlass_cppgen.backend.operation import * +from cutlass_cppgen.backend.reduction_operation import * +from cutlass_cppgen.backend.type_hint import * +from cutlass_cppgen.backend.utils import * +from cutlass_cppgen.backend.utils.device import device_cc compiler = ArtifactManager() diff --git a/python/cutlass/backend/arguments.py b/python/cutlass/backend/arguments.py index 7c2664e0..b1b0656a 100644 --- a/python/cutlass/backend/arguments.py +++ b/python/cutlass/backend/arguments.py @@ -33,16 +33,16 @@ from math import prod from typing import Union -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") import numpy as np -import cutlass -from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend -from cutlass.backend.memory_manager import DevicePtrWrapper -from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor +import cutlass_cppgen +from cutlass_cppgen.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend +from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor class ArgumentBase: @@ -122,7 +122,7 @@ class ArgumentBase: Frees allocated device-side memory """ # Free any device memory allocated manually - if not cutlass.use_rmm: + if not cutlass_cppgen.use_rmm: for name, buf in self.buffers.items(): if isinstance(buf, DevicePtrWrapper): err, = cudart.cudaFree(buf.ptr) diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 1e8f5774..91fbc23e 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -37,7 +37,7 @@ from cutlass_library import ( KernelScheduleType, TileSchedulerType ) -from cutlass.backend.library import DataTypeSizeBytes +from cutlass_cppgen.backend.library import DataTypeSizeBytes class GemmCoord_(ctypes.Structure): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index b4715602..1b78b513 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -37,17 +37,17 @@ import sqlite3 import subprocess import tempfile -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") nvrtc = lazy_import("cuda.nvrtc") from cutlass_library import SubstituteTemplate -import cutlass -from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger -from cutlass.backend.gemm_operation import GemmOperationUniversal -from cutlass.backend.library import ApiVersion -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger +from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.backend.utils.device import device_cc IncludeTemplate = r"""#include "${include}" """ @@ -93,7 +93,7 @@ class CompilationOptions: opts.append(f"--include-path={incl}") arch_flag = f"-arch=sm_{self.arch}" - if self.arch == 90 and int(cutlass.nvcc_version().split('.')[0]) >= 12: + if self.arch == 90 and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: arch_flag += "a" opts.append(arch_flag) @@ -366,7 +366,7 @@ class ArtifactManager: CUTLASS_PATH + "/python/cutlass/cpp/include", ] - cutlass.initialize_cuda_context() + cutlass_cppgen.initialize_cuda_context() arch = device_cc() host_compile_options = CompilationOptions( diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py index a261ce90..03679c43 100644 --- a/python/cutlass/backend/conv2d_operation.py +++ b/python/cutlass/backend/conv2d_operation.py @@ -34,7 +34,7 @@ from __future__ import annotations import ctypes from typing import Union -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") from cutlass_library import SubstituteTemplate import numpy as np @@ -65,17 +65,17 @@ from cutlass_library import ( get_complex_from_real, ) -from cutlass.backend.arguments import ArgumentBase -from cutlass.backend.c_types import dim3_, get_conv2d_arguments -from cutlass.backend.library import ( +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments +from cutlass_cppgen.backend.library import ( EmissionType, TensorDescription, TileDescription, ) -from cutlass.backend.memory_manager import device_mem_alloc -from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass.backend.utils.device import to_device_ptr -from cutlass.shape import GemmCoord +from cutlass_cppgen.backend.memory_manager import device_mem_alloc +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.utils.device import to_device_ptr +from cutlass_cppgen.shape import GemmCoord class Conv2dArguments(ArgumentBase): @@ -84,9 +84,9 @@ class Conv2dArguments(ArgumentBase): user-provide tensors into the kernel's argument. :param operation: the Conv2d operation to take the argument - :type operation: :class:`cutlass.backend.Conv2dOperation` + :type operation: :class:`cutlass_cppgen.backend.Conv2dOperation` :param problem_size: the Conv2d problem size - :type problem_size: :class:`cutlass.shape.Conv2dProblemSize` + :type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param B: tensor B @@ -98,7 +98,7 @@ class Conv2dArguments(ArgumentBase): :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial :type split_k_mode: cutlass_library.library.SplitKMode, optional :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) :type stream: :class:`cuda.cuda.CUstream` """ @@ -380,19 +380,19 @@ class Conv2dOperation: :type arch: int :param tile_description: tile description - :type tile_description: :class:`cutlass.backend.TileDescription` + :type tile_description: :class:`cutlass_cppgen.backend.TileDescription` :param A: tensor A description - :type A: :class:`cutlass.backend.TensorDescription` + :type A: :class:`cutlass_cppgen.backend.TensorDescription` :param B: tensor B description - :type B: :class:`cutlass.backend.TensorDescription` + :type B: :class:`cutlass_cppgen.backend.TensorDescription` :param C: tensor C description - :type C: :class:`cutlass.backend.TensorDescription` + :type C: :class:`cutlass_cppgen.backend.TensorDescription` :param D: tensor D description - :type D: :class:`cutlass.backend.TensorDescription` + :type D: :class:`cutlass_cppgen.backend.TensorDescription` :param element_epilogue: element type for computation in epilogue \ :type element_epilogue: cutlass_library.library.DataType @@ -444,7 +444,7 @@ class Conv2dOperation: Launch the cuda kernel with input arguments :param arguments: conv2d arguments - :type arguments: :class:`cutlass.backend.Conv2dArguments` + :type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments` """ # launch the kernel diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index d7f2bd62..49ad79c9 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -36,10 +36,10 @@ from cutlass_library import SubstituteTemplate import numpy as np from cutlass_library import DataType, DataTypeTag -from cutlass.backend.c_types import MatrixCoord_, tuple_factory -from cutlass.backend.frontend import NumpyFrontend -from cutlass.backend.library import ActivationOp, ActivationOpTag -from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor +from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory +from cutlass_cppgen.backend.frontend import NumpyFrontend +from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag +from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor dtype2ctype = { DataType.f16: ctypes.c_uint16, diff --git a/python/cutlass/backend/evt/__init__.py b/python/cutlass/backend/evt/__init__.py index 35ce4aa3..b61e983a 100644 --- a/python/cutlass/backend/evt/__init__.py +++ b/python/cutlass/backend/evt/__init__.py @@ -30,5 +30,5 @@ # ################################################################################################# -from cutlass.backend.evt.epilogue import EpilogueFunctorVisitor -from cutlass.backend.evt.frontend import PythonASTFrontend +from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend diff --git a/python/cutlass/backend/evt/backend/__init__.py b/python/cutlass/backend/evt/backend/__init__.py index bb7c0834..a1654548 100644 --- a/python/cutlass/backend/evt/backend/__init__.py +++ b/python/cutlass/backend/evt/backend/__init__.py @@ -30,7 +30,7 @@ # ################################################################################################# -from cutlass.backend.evt.backend.sm80_emitter import Sm80Emitter -import cutlass.backend.evt.backend.sm80_nodes as sm80_nodes -from cutlass.backend.evt.backend.sm90_emitter import Sm90Emitter -import cutlass.backend.evt.backend.sm90_nodes as sm90_nodes +from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter +import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes +from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes diff --git a/python/cutlass/backend/evt/backend/emitter_base.py b/python/cutlass/backend/evt/backend/emitter_base.py index 738dcf46..39723844 100644 --- a/python/cutlass/backend/evt/backend/emitter_base.py +++ b/python/cutlass/backend/evt/backend/emitter_base.py @@ -35,7 +35,7 @@ Base class for Epilogue Visitor Emitter """ from cutlass_library import DataTypeTag -from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR class FusionCallbacks: diff --git a/python/cutlass/backend/evt/backend/sm80_emitter.py b/python/cutlass/backend/evt/backend/sm80_emitter.py index a22e3379..868453a7 100644 --- a/python/cutlass/backend/evt/backend/sm80_emitter.py +++ b/python/cutlass/backend/evt/backend/sm80_emitter.py @@ -34,8 +34,8 @@ Emitter for Sm80 Epilogue Visitor """ -from cutlass.backend.evt.backend.emitter_base import FusionCallbacks -from cutlass.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend import GemmOperationUniversal class Sm80Emitter: diff --git a/python/cutlass/backend/evt/backend/sm80_nodes.py b/python/cutlass/backend/evt/backend/sm80_nodes.py index aafc38e2..b9fc5613 100644 --- a/python/cutlass/backend/evt/backend/sm80_nodes.py +++ b/python/cutlass/backend/evt/backend/sm80_nodes.py @@ -32,7 +32,7 @@ from cutlass_library import DataTypeSize, DataTypeTag -from cutlass.backend.evt.ir import ( +from cutlass_cppgen.backend.evt.ir import ( # Load Node AccumulatorImpl, AuxLoadImpl, @@ -50,7 +50,7 @@ from cutlass.backend.evt.ir import ( ScalarReductionImpl ) -from cutlass.backend.library import ( +from cutlass_cppgen.backend.library import ( FloatRoundStyleTag, FunctionalOp, op_tag, diff --git a/python/cutlass/backend/evt/backend/sm90_emitter.py b/python/cutlass/backend/evt/backend/sm90_emitter.py index 3d5b5046..3c058aa8 100644 --- a/python/cutlass/backend/evt/backend/sm90_emitter.py +++ b/python/cutlass/backend/evt/backend/sm90_emitter.py @@ -35,8 +35,8 @@ Emitter for Sm90 Epilogue Visitor """ from cutlass_library import DataTypeTag, EpilogueScheduleTag -from cutlass.backend import GemmOperationUniversal -from cutlass.backend.evt.backend.emitter_base import FusionCallbacks +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks class CollectiveEpilogue: diff --git a/python/cutlass/backend/evt/backend/sm90_nodes.py b/python/cutlass/backend/evt/backend/sm90_nodes.py index 62ad5004..43601a42 100644 --- a/python/cutlass/backend/evt/backend/sm90_nodes.py +++ b/python/cutlass/backend/evt/backend/sm90_nodes.py @@ -33,7 +33,7 @@ from pycute import product from cutlass_library import DataTypeSize, DataTypeTag -from cutlass.backend.evt.ir import ( +from cutlass_cppgen.backend.evt.ir import ( # Load Node AccumulatorImpl, AuxLoadImpl, @@ -53,7 +53,7 @@ from cutlass.backend.evt.ir import ( StoreNode, StoreDImpl, ) -from cutlass.backend.library import ( +from cutlass_cppgen.backend.library import ( FloatRoundStyleTag, FunctionalOp, op_tag, diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py index 58bd5769..92f71a10 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass/backend/evt/epilogue.py @@ -36,16 +36,16 @@ Epilogue Visitor interface for compiling, and running visitor-based epilogue. import ctypes -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") from cutlass_library import DataType import numpy as np -from cutlass.backend.epilogue import EpilogueFunctorBase -import cutlass.backend.evt.backend -from cutlass.backend.frontend import TensorFrontend -from cutlass.utils.datatypes import is_numpy_tensor -from cutlass.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase +import cutlass_cppgen.backend.evt.backend +from cutlass_cppgen.backend.frontend import TensorFrontend +from cutlass_cppgen.utils.datatypes import is_numpy_tensor +from cutlass_cppgen.backend.evt.passes.util import cc_map class EpilogueFunctorVisitor(EpilogueFunctorBase): @@ -58,7 +58,7 @@ class EpilogueFunctorVisitor(EpilogueFunctorBase): """ def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: # Type of Emitter based on CC - self.emit_cls = getattr(cutlass.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") + self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") # Visitor Types self.visitor = visitor diff --git a/python/cutlass/backend/evt/frontend/__init__.py b/python/cutlass/backend/evt/frontend/__init__.py index f2cd3c97..f2323278 100644 --- a/python/cutlass/backend/evt/frontend/__init__.py +++ b/python/cutlass/backend/evt/frontend/__init__.py @@ -30,4 +30,4 @@ # ################################################################################################# -from cutlass.backend.evt.frontend.python_ast import PythonASTFrontend +from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index 06d3477b..c150bf20 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -37,14 +37,14 @@ Base class for Python EVT Frontend from typing import Union from cutlass_library import DataType -from cutlass.backend.evt.ir import ( +from cutlass_cppgen.backend.evt.ir import ( ComputeNode, DAGIR, LayoutNode, LoadNode, StoreNode, ) -from cutlass.backend.evt.passes import ( +from cutlass_cppgen.backend.evt.passes import ( EVTGraphDrawer, EVTPassManager, GetSmemSize, @@ -56,9 +56,9 @@ from cutlass.backend.evt.passes import ( PassPreprocessRed, PassShapeTypePropagation, ) -from cutlass.backend.utils import device_cc -from cutlass.epilogue.evt_ops import permute, reshape -from cutlass.utils.datatypes import library_type +from cutlass_cppgen.backend.utils import device_cc +from cutlass_cppgen.epilogue.evt_ops import permute, reshape +from cutlass_cppgen.utils.datatypes import library_type class EVTFrontendBase: diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index 258e9421..8727b754 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -40,10 +40,10 @@ import textwrap from cutlass_library import DataType -import cutlass -from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase -from cutlass.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu -from cutlass.backend.library import FunctionalOp +import cutlass_cppgen +from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase +from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu +from cutlass_cppgen.backend.library import FunctionalOp class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): diff --git a/python/cutlass/backend/evt/ir/__init__.py b/python/cutlass/backend/evt/ir/__init__.py index 5d55adea..0f9e3f81 100644 --- a/python/cutlass/backend/evt/ir/__init__.py +++ b/python/cutlass/backend/evt/ir/__init__.py @@ -30,10 +30,10 @@ # ################################################################################################# -from cutlass.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl -from cutlass.backend.evt.ir.dag_ir import DAGIR -from cutlass.backend.evt.ir.layout_nodes import LayoutNode -from cutlass.backend.evt.ir.load_nodes import ( +from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl +from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR +from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode +from cutlass_cppgen.backend.evt.ir.load_nodes import ( LoadNode, AccumulatorImpl, LoadSrcImpl, @@ -42,8 +42,8 @@ from cutlass.backend.evt.ir.load_nodes import ( ColumnBroadcastImpl, ScalarBroadcastImpl ) -from cutlass.backend.evt.ir.node import TopoVisitorNode, NoOpImpl -from cutlass.backend.evt.ir.store_nodes import ( +from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl +from cutlass_cppgen.backend.evt.ir.store_nodes import ( StoreNode, StoreDImpl, AuxStoreImpl, diff --git a/python/cutlass/backend/evt/ir/compute_nodes.py b/python/cutlass/backend/evt/ir/compute_nodes.py index 6c9f51b2..02b05358 100644 --- a/python/cutlass/backend/evt/ir/compute_nodes.py +++ b/python/cutlass/backend/evt/ir/compute_nodes.py @@ -34,8 +34,8 @@ Python registration for compute nodes in EVT """ -from cutlass.backend.evt.ir.node import NodeBase, ImplBase -from cutlass.backend.library import FloatRoundStyle +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase +from cutlass_cppgen.backend.library import FloatRoundStyle class ComputeImplBase(ImplBase): diff --git a/python/cutlass/backend/evt/ir/dag_ir.py b/python/cutlass/backend/evt/ir/dag_ir.py index bb9121a7..e7e9f75a 100644 --- a/python/cutlass/backend/evt/ir/dag_ir.py +++ b/python/cutlass/backend/evt/ir/dag_ir.py @@ -38,10 +38,10 @@ import networkx as nx from cutlass_library import DataType -from cutlass.backend.evt.ir.compute_nodes import ComputeNode -from cutlass.backend.evt.ir.node import NodeBase -from cutlass.backend.library import ActivationOp -from cutlass.backend.utils import device_cc +from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.library import ActivationOp +from cutlass_cppgen.backend.utils import device_cc class DAGIR: diff --git a/python/cutlass/backend/evt/ir/layout_nodes.py b/python/cutlass/backend/evt/ir/layout_nodes.py index 81ddf094..1095e2ab 100644 --- a/python/cutlass/backend/evt/ir/layout_nodes.py +++ b/python/cutlass/backend/evt/ir/layout_nodes.py @@ -41,10 +41,10 @@ from copy import deepcopy from cutlass_library import LayoutType from pycute import product, flatten -import cutlass -from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list -from cutlass.backend.evt.ir.node import NodeBase -from cutlass.backend.evt.ir.tensor import Tensor +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.ir.tensor import Tensor class PermutationImpl: diff --git a/python/cutlass/backend/evt/ir/load_nodes.py b/python/cutlass/backend/evt/ir/load_nodes.py index 73bf9825..bff0aaa2 100644 --- a/python/cutlass/backend/evt/ir/load_nodes.py +++ b/python/cutlass/backend/evt/ir/load_nodes.py @@ -36,9 +36,9 @@ Load nodes and implementations import ctypes -from cutlass.backend.c_types import tuple_factory -from cutlass.backend.epilogue import dtype2ctype, to_ctype_value -from cutlass.backend.evt.ir.node import NodeBase, ImplBase +from cutlass_cppgen.backend.c_types import tuple_factory +from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase class LoadImplBase(ImplBase): diff --git a/python/cutlass/backend/evt/ir/node.py b/python/cutlass/backend/evt/ir/node.py index b5d4fdd1..e2b3a34a 100644 --- a/python/cutlass/backend/evt/ir/node.py +++ b/python/cutlass/backend/evt/ir/node.py @@ -39,8 +39,8 @@ from re import sub from cutlass_library import LayoutType -from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple -from cutlass.backend.evt.ir.tensor import Tensor +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple +from cutlass_cppgen.backend.evt.ir.tensor import Tensor class ImplBase: @@ -170,7 +170,7 @@ class NodeBase: @property def tensor(self) -> Tensor: """ - Return the output tensor (concept: cutlass.backend.evt.ir.tensor) + Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) """ return self._tensor diff --git a/python/cutlass/backend/evt/ir/store_nodes.py b/python/cutlass/backend/evt/ir/store_nodes.py index a3f06645..708405e0 100644 --- a/python/cutlass/backend/evt/ir/store_nodes.py +++ b/python/cutlass/backend/evt/ir/store_nodes.py @@ -38,11 +38,11 @@ import ctypes from cutlass_library import DataType -from cutlass.backend.c_types import tuple_factory -from cutlass.backend.epilogue import dtype2ctype, to_ctype_value -from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl -from cutlass.backend.evt.ir.tensor import Tensor -from cutlass.backend.library import FloatRoundStyle, FunctionalOp +from cutlass_cppgen.backend.c_types import tuple_factory +from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl +from cutlass_cppgen.backend.evt.ir.tensor import Tensor +from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp class StoreImplBase(ImplBase): @@ -249,7 +249,7 @@ class StoreNode(NodeBase): @property def store_tensor(self) -> Tensor: """ - Return the output tensor (concept: cutlass.backend.evt.ir.tensor) + Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor) """ return self._store_tensor diff --git a/python/cutlass/backend/evt/ir/tensor.py b/python/cutlass/backend/evt/ir/tensor.py index 9eea9f42..1a28b730 100644 --- a/python/cutlass/backend/evt/ir/tensor.py +++ b/python/cutlass/backend/evt/ir/tensor.py @@ -36,7 +36,7 @@ High-level class for tensor from cutlass_library import LayoutType -from cutlass.backend.evt.ir.layout_algorithm import ( +from cutlass_cppgen.backend.evt.ir.layout_algorithm import ( Layout, broadcast, canonicalization, @@ -44,7 +44,7 @@ from cutlass.backend.evt.ir.layout_algorithm import ( reshape, _reverse_tuple ) -from cutlass.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type +from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type class Tensor: diff --git a/python/cutlass/backend/evt/passes/__init__.py b/python/cutlass/backend/evt/passes/__init__.py index c2998397..badc38d9 100644 --- a/python/cutlass/backend/evt/passes/__init__.py +++ b/python/cutlass/backend/evt/passes/__init__.py @@ -30,13 +30,13 @@ # ################################################################################################# -from cutlass.backend.evt.passes.graph_drawer import EVTGraphDrawer -from cutlass.backend.evt.passes.pass_argument_type import PassGetArgumentType -from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree -from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD -from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination -from cutlass.backend.evt.passes.pass_manager import EVTPassManager -from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed -from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation -from cutlass.backend.evt.passes.smem_size_calculator import GetSmemSize +from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer +from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType +from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize diff --git a/python/cutlass/backend/evt/passes/graph_drawer.py b/python/cutlass/backend/evt/passes/graph_drawer.py index fd05bd92..8a28c6e4 100644 --- a/python/cutlass/backend/evt/passes/graph_drawer.py +++ b/python/cutlass/backend/evt/passes/graph_drawer.py @@ -35,7 +35,7 @@ import subprocess from cutlass_library import DataTypeTag -from cutlass.backend.evt.ir.dag_ir import DAGIR +from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR _COLOR_MAP = { diff --git a/python/cutlass/backend/evt/passes/pass_argument_type.py b/python/cutlass/backend/evt/passes/pass_argument_type.py index 0c5cc1d2..c458f799 100644 --- a/python/cutlass/backend/evt/passes/pass_argument_type.py +++ b/python/cutlass/backend/evt/passes/pass_argument_type.py @@ -34,12 +34,12 @@ Construct the epilogue visitor argument type """ -from cutlass.backend.c_types import visitor_factory -from cutlass.backend.evt.ir import TopoVisitorNode -from cutlass.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree -from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass.backend.evt.passes.pass_manager import EVTPassBase -from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.c_types import visitor_factory +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode +from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation class PassGetArgumentType(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py index 9fad1de3..5eae2f92 100644 --- a/python/cutlass/backend/evt/passes/pass_dag_2_tree.py +++ b/python/cutlass/backend/evt/passes/pass_dag_2_tree.py @@ -37,10 +37,10 @@ by the topological visitor, while the rest of the graph will be implemented with from copy import deepcopy -from cutlass.backend.evt.ir import DAGIR, TopoVisitorNode -from cutlass.backend.evt.passes.pass_get_impl import PassGetImpl -from cutlass.backend.evt.passes.pass_manager import EVTPassBase -from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation class PassDAG2Tree(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_fix_element_d.py b/python/cutlass/backend/evt/passes/pass_fix_element_d.py index 3ef697ca..0d57c5b7 100644 --- a/python/cutlass/backend/evt/passes/pass_fix_element_d.py +++ b/python/cutlass/backend/evt/passes/pass_fix_element_d.py @@ -37,8 +37,8 @@ In Sm90 epilogue visitor, the node writing D to gmem does not have internal element converter, so the compute node producing D must have element_output = type(D). """ -from cutlass.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination -from cutlass.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase class PassFixElementD(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_get_impl.py b/python/cutlass/backend/evt/passes/pass_get_impl.py index a883e9ff..90fdafe7 100644 --- a/python/cutlass/backend/evt/passes/pass_get_impl.py +++ b/python/cutlass/backend/evt/passes/pass_get_impl.py @@ -39,13 +39,13 @@ on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadca This pass infers the underlying impl of each node """ -import cutlass.backend.evt.backend as evt_backend -from cutlass.backend.evt.ir import DAGIR, LoadNode -from cutlass.backend.evt.passes.pass_fix_element_d import PassFixElementD -from cutlass.backend.evt.passes.pass_manager import EVTPassBase -from cutlass.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination -from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation -from cutlass.backend.evt.passes.util import cc_map +import cutlass_cppgen.backend.evt.backend as evt_backend +from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map class PassGetImpl(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_layout_elimination.py b/python/cutlass/backend/evt/passes/pass_layout_elimination.py index 48c5d295..af147969 100644 --- a/python/cutlass/backend/evt/passes/pass_layout_elimination.py +++ b/python/cutlass/backend/evt/passes/pass_layout_elimination.py @@ -36,9 +36,9 @@ Eliminate layout manipulation nodes from copy import deepcopy -from cutlass.backend.evt.ir import DAGIR, LayoutNode -from cutlass.backend.evt.passes.pass_manager import EVTPassBase -from cutlass.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation class PassLayoutManipulateElimination(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_manager.py b/python/cutlass/backend/evt/passes/pass_manager.py index e5b94048..e8b46bdd 100644 --- a/python/cutlass/backend/evt/passes/pass_manager.py +++ b/python/cutlass/backend/evt/passes/pass_manager.py @@ -38,8 +38,8 @@ from typing import Any import networkx as nx -from cutlass.backend.evt.ir import DAGIR -from cutlass.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.evt.ir import DAGIR +from cutlass_cppgen.backend.evt.passes.util import cc_map class EVTPassBase: diff --git a/python/cutlass/backend/evt/passes/pass_no_op_elimination.py b/python/cutlass/backend/evt/passes/pass_no_op_elimination.py index 148f87f8..13107eb1 100644 --- a/python/cutlass/backend/evt/passes/pass_no_op_elimination.py +++ b/python/cutlass/backend/evt/passes/pass_no_op_elimination.py @@ -36,8 +36,8 @@ No op elimination node from typing import Any -from cutlass.backend.evt.ir import NoOpImpl -from cutlass.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.ir import NoOpImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase class PassNoOpElimination(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_preprocess_red.py b/python/cutlass/backend/evt/passes/pass_preprocess_red.py index 9a342636..6423a2b8 100644 --- a/python/cutlass/backend/evt/passes/pass_preprocess_red.py +++ b/python/cutlass/backend/evt/passes/pass_preprocess_red.py @@ -38,8 +38,8 @@ This pass fuses these into a single store node, and then replaces all uses of th current node with the new store node. """ -from cutlass.backend.evt.ir import ComputeNode, StoreNode -from cutlass.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase class PassPreprocessRed(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py b/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py index fc493626..cb90a82c 100644 --- a/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py +++ b/python/cutlass/backend/evt/passes/pass_shape_type_propagation.py @@ -34,9 +34,9 @@ Shape and type propagation pass """ -from cutlass.backend.evt.ir.node import NodeBase -from cutlass.backend.evt.passes.pass_manager import EVTPassBase -from cutlass.backend.evt.passes.pass_preprocess_red import PassPreprocessRed +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed class PassShapeTypePropagation(EVTPassBase): diff --git a/python/cutlass/backend/evt/passes/smem_size_calculator.py b/python/cutlass/backend/evt/passes/smem_size_calculator.py index d28bf3a0..4896840e 100644 --- a/python/cutlass/backend/evt/passes/smem_size_calculator.py +++ b/python/cutlass/backend/evt/passes/smem_size_calculator.py @@ -37,9 +37,9 @@ Compute the shared memory size in bytes import cutlass_library from pycute import shape_div, product -import cutlass -from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR -from cutlass.backend.library import DataTypeSize +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR +from cutlass_cppgen.backend.library import DataTypeSize class GetSmemSize: diff --git a/python/cutlass/backend/frontend.py b/python/cutlass/backend/frontend.py index c1fb97c3..a959976b 100644 --- a/python/cutlass/backend/frontend.py +++ b/python/cutlass/backend/frontend.py @@ -31,12 +31,12 @@ ################################################################################################# from __future__ import annotations -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") import numpy as np -from cutlass.backend.memory_manager import device_mem_alloc, todevice -from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor class NumpyFrontend: diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 0305abd5..cf6bcc18 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -35,7 +35,7 @@ import copy import ctypes import enum -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") from cutlass_library import SubstituteTemplate @@ -74,8 +74,8 @@ from cutlass_library import ( TileSchedulerType, get_complex_from_real ) -from cutlass.backend.arguments import ArgumentBase -from cutlass.backend.c_types import ( +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import ( GemmCoord_, GemmCoordBatched_, GenericMainloopArguments3x_, @@ -88,7 +88,7 @@ from cutlass.backend.c_types import ( get_mainloop_arguments_3x, get_tile_scheduler_arguments_3x, ) -from cutlass.backend.library import ( +from cutlass_cppgen.backend.library import ( ApiVersion, EmissionType, SchedulerMode, @@ -97,11 +97,11 @@ from cutlass.backend.library import ( TileDescription, api_version, ) -from cutlass.backend.memory_manager import device_mem_alloc, todevice -from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass.backend.type_hint import GemmOperation, Tensor -from cutlass.backend.utils.device import device_sm_count -from cutlass.shape import GemmCoord, MatrixCoord +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor +from cutlass_cppgen.backend.utils.device import device_sm_count +from cutlass_cppgen.shape import GemmCoord, MatrixCoord ################################################################################ @@ -116,9 +116,9 @@ def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int: Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``. :param layout: layout of the tensor - :type layout: cutlass.shape.LayoutType + :type layout: cutlass_cppgen.shape.LayoutType :param shape: shape of the tensor - :type shape: cutlass.shape.MatrixCoord + :type shape: cutlass_cppgen.shape.MatrixCoord :return: leading dimension of the tensor :rtype: int @@ -144,11 +144,11 @@ class GemmArguments2x(ArgumentBase): user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass.backend.GemmOperationUniversal` | - :class:`cutlass.backend.GemmOperationGrouped` + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.shape.GemmCoord` + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -166,7 +166,7 @@ class GemmArguments2x(ArgumentBase): :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) :type stream: :class:`cuda.cuda.CUstream` @@ -371,11 +371,11 @@ class GemmArguments2xStreamK(GemmArguments2x): user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass.backend.GemmOperationUniversal` | - :class:`cutlass.backend.GemmOperationGrouped` + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.shape.GemmCoord` + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -393,7 +393,7 @@ class GemmArguments2xStreamK(GemmArguments2x): :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` """ def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): @@ -483,11 +483,11 @@ class GemmArguments3x(GemmArguments2x): user-provide tensors into the kernel's argument :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass.backend.GemmOperationUniversal` | - :class:`cutlass.backend.GemmOperationGrouped` + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.shape.GemmCoord` + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -505,7 +505,7 @@ class GemmArguments3x(GemmArguments2x): :type gemm_mode: GemmUniversalMode :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` """ def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): @@ -631,11 +631,11 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo or 3x arguments depending on the `arch` field specified in `operation`. :param operation: the GEMM operation to take the argument - :type operation: :class:`cutlass.backend.GemmOperationUniversal` | - :class:`cutlass.backend.GemmOperationGrouped` + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` :param problem_size: GEMM problem size gemm(M, N, K) - :type operation: :class:`cutlass.shape.GemmCoord` + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` :param A: tensor A :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray @@ -653,7 +653,7 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` """ if operation.swizzling_functor == SwizzlingFunctor.StreamK: if operation.api == ApiVersion.v3x: @@ -670,10 +670,10 @@ class GemmGroupedArguments: user-provide tensors into the kernel's argument :param operation: the GEMM Grouped operation to take the argument - :type operation: :class:`cutlass.backend.GemmOperationGrouped` + :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped` :param problem_size: list of GEMM problem size gemm(M, N, K) - :type operation: list[:class:`cutlass.shape.GemmCoord`] + :type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`] :param A: list of tensor A :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] @@ -688,7 +688,7 @@ class GemmGroupedArguments: :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] :param output_op: output operator, optional - :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) :type stream: :class:`cuda.cuda.CUstream` diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index 51fe1fe9..a8b113b4 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -417,7 +417,7 @@ def CalculateSmemUsagePerStage(operation): :param op: operation for which the maximum stages should be computed. If stages are set via the `op.tile_description.stages` parameter, this setting is ignored in the present calculation - :type op: cutlass.backend.Operation + :type op: cutlass_cppgen.backend.Operation :return: number of bytes of shared memory consumed by a single stage :rtype: int @@ -442,7 +442,7 @@ def CalculateSmemUsage(operation): :param op: operation for which the maximum stages should be computed. If stages are set via the `op.tile_description.stages` parameter, this setting is ignored in the present calculation - :type op: cutlass.backend.Operation + :type op: cutlass_cppgen.backend.Operation :return: int """ diff --git a/python/cutlass/backend/memory_manager.py b/python/cutlass/backend/memory_manager.py index 9d7daf17..30e6bb31 100644 --- a/python/cutlass/backend/memory_manager.py +++ b/python/cutlass/backend/memory_manager.py @@ -32,11 +32,11 @@ import numpy as np -import cutlass -from cutlass.utils.datatypes import is_numpy_tensor -from cutlass.utils.lazy_import import lazy_import +import cutlass_cppgen +from cutlass_cppgen.utils.datatypes import is_numpy_tensor +from cutlass_cppgen.utils.lazy_import import lazy_import -if cutlass.use_rmm: +if cutlass_cppgen.use_rmm: import rmm else: cudart = lazy_import("cuda.cudart") @@ -73,7 +73,7 @@ def _todevice(host_data): """ Helper for transferring host data to device memory """ - if cutlass.use_rmm: + if cutlass_cppgen.use_rmm: return rmm.DeviceBuffer.to_device(host_data.tobytes()) else: nbytes = len(host_data.tobytes()) @@ -100,7 +100,7 @@ def todevice(host_data, dtype=np.float32): def device_mem_alloc(size): - if cutlass.use_rmm: + if cutlass_cppgen.use_rmm: return rmm.DeviceBuffer(size=size) else: err, ptr = cudart.cudaMalloc(size) @@ -114,7 +114,7 @@ def align_size(size, alignment=256): def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): - if cutlass.use_rmm: + if cutlass_cppgen.use_rmm: memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size) return memory_pool else: diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index 5b5400df..1f4b26ad 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -31,10 +31,10 @@ ################################################################################################# import ctypes -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") -from cutlass.backend.utils.device import device_cc +from cutlass_cppgen.backend.utils.device import device_cc _supports_cluster_launch = None diff --git a/python/cutlass/backend/reduction_operation.py b/python/cutlass/backend/reduction_operation.py index 559d51c3..535cea2c 100644 --- a/python/cutlass/backend/reduction_operation.py +++ b/python/cutlass/backend/reduction_operation.py @@ -34,7 +34,7 @@ from __future__ import annotations import ctypes from typing import Union -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") import numpy as np @@ -47,14 +47,14 @@ from cutlass_library import ( SubstituteTemplate ) -import cutlass -from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params -from cutlass.backend.frontend import NumpyFrontend, TorchFrontend -from cutlass.backend.library import TensorDescription -from cutlass.backend.memory_manager import DevicePtrWrapper -from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass.shape import MatrixCoord -from cutlass.utils.datatypes import is_numpy_tensor, is_torch_tensor +import cutlass_cppgen +from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params +from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend +from cutlass_cppgen.backend.library import TensorDescription +from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.shape import MatrixCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor class ReductionOperation: @@ -200,7 +200,7 @@ class ReductionArguments: Frees allocated device-side memory """ # Free any device memory allocated manually - if not cutlass.use_rmm: + if not cutlass_cppgen.use_rmm: for attr in ["destination_buffer", "source_buffer"]: if hasattr(self, attr): buf = getattr(self, attr) diff --git a/python/cutlass/backend/utils/__init__.py b/python/cutlass/backend/utils/__init__.py index 638a97b1..0bae3bac 100644 --- a/python/cutlass/backend/utils/__init__.py +++ b/python/cutlass/backend/utils/__init__.py @@ -30,4 +30,4 @@ # ################################################################################ -from cutlass.backend.utils.device import check_cuda_errors, device_cc +from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc diff --git a/python/cutlass/backend/utils/device.py b/python/cutlass/backend/utils/device.py index 4f7620d9..9ed4096a 100644 --- a/python/cutlass/backend/utils/device.py +++ b/python/cutlass/backend/utils/device.py @@ -35,12 +35,12 @@ Utility functions for interacting with the device """ from __future__ import annotations -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") -import cutlass -from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor +import cutlass_cppgen +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor def check_cuda_errors(result: list): @@ -77,7 +77,7 @@ def device_cc(device: int = -1) -> int: :rtype: int """ if device == -1: - device = cutlass.device_id() + device = cutlass_cppgen.device_id() deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) major = str(deviceProp.major) @@ -87,7 +87,7 @@ def device_cc(device: int = -1) -> int: def device_sm_count(device: int = -1): if device == -1: - device = cutlass.device_id() + device = cutlass_cppgen.device_id() err, device_sm_count = cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device ) diff --git a/python/cutlass/emit/__init__.py b/python/cutlass/emit/__init__.py index e1026558..8e4121b5 100644 --- a/python/cutlass/emit/__init__.py +++ b/python/cutlass/emit/__init__.py @@ -30,4 +30,4 @@ # ################################################################################################# -from cutlass.emit.pytorch import pytorch +from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py index 4d9b8763..58f94e15 100644 --- a/python/cutlass/emit/common.py +++ b/python/cutlass/emit/common.py @@ -34,10 +34,10 @@ Common utilities for emitting CUTLASS kernels """ -import cutlass +import cutlass_cppgen # Strings used for printing information about the generation of emitted scripts -_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass.__version__} Python interface (https://github.com/nvidia/cutlass/python)" +_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)" _CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR} diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index e759596c..86374b8b 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -39,9 +39,9 @@ Example usage with JIT compilation: .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) op = plan.construct() - mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) # Generate inputs for the GEMM A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] @@ -55,9 +55,9 @@ Example usage without JIT compilation: .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) op = plan.construct() - cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') + cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') After this call, the directory ``output`` contains ``setup.py``, ``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from @@ -83,12 +83,12 @@ import os from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate -from cutlass import CUTLASS_PATH, logger, swizzle -from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal -from cutlass.backend.conv2d_operation import Conv2dOperation -from cutlass.backend.library import ApiVersion -from cutlass.emit import common -from cutlass.utils.datatypes import is_torch_available +from cutlass_cppgen import CUTLASS_PATH, logger, swizzle +from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal +from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.emit import common +from cutlass_cppgen.utils.datatypes import is_torch_available if is_torch_available(): import torch diff --git a/python/cutlass/epilogue/__init__.py b/python/cutlass/epilogue/__init__.py index 43a0beb6..faf6896e 100644 --- a/python/cutlass/epilogue/__init__.py +++ b/python/cutlass/epilogue/__init__.py @@ -30,7 +30,7 @@ # ################################################################################################# -from cutlass.epilogue.epilogue import ( +from cutlass_cppgen.epilogue.epilogue import ( get_activations, get_activation_epilogue, gelu, @@ -44,7 +44,7 @@ from cutlass.epilogue.epilogue import ( trace ) -from cutlass.epilogue.evt_ops import ( +from cutlass_cppgen.epilogue.evt_ops import ( max, multiply_add, sum, diff --git a/python/cutlass/epilogue/epilogue.py b/python/cutlass/epilogue/epilogue.py index 76c75e20..16d1fec8 100644 --- a/python/cutlass/epilogue/epilogue.py +++ b/python/cutlass/epilogue/epilogue.py @@ -39,11 +39,11 @@ code like the following for GEMM: .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) - plan.activation = cutlass.epilogue.relu + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu """ -from cutlass.backend import epilogue, device_cc +from cutlass_cppgen.backend import epilogue, device_cc gelu = epilogue.gelu @@ -111,7 +111,7 @@ def get_activation_epilogue( """ Frontend for EVT that generates epilogue functor through tracing the input function """ -from cutlass.backend.evt.frontend import PythonASTFrontend +from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend def trace(fn, example_tensors, **kwargs): @@ -124,7 +124,7 @@ def trace(fn, example_tensors, **kwargs): .. hightlight:: python .. code-block:: python - import cutlass.backend.evt + import cutlass_cppgen.backend.evt # Define epilogue function as Python callable def example_fn(accum, C, alpha, beta, gamma): @@ -142,7 +142,7 @@ def trace(fn, example_tensors, **kwargs): } # Generate the epilogue functor - epilogue_visitor = cutlass.epilogue.trace(example_fn, example_inputs) + epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs) """ if callable(fn): class EpilogueFunctor(PythonASTFrontend): diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index 0d2ef36e..7d8e2c01 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -36,7 +36,7 @@ Collection of builtin functions used for host reference in EVT import numpy as np -from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor if is_torch_available(): import torch diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 9c6f0b39..da321577 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -40,9 +40,9 @@ import logging import cutlass_library from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode -import cutlass -from cutlass.utils.check import valid_stage_count -from cutlass.utils.datatypes import td_from_profiler_td, td_from_profiler_op +import cutlass_cppgen +from cutlass_cppgen.utils.check import valid_stage_count +from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op _generator_ccs = [50, 60, 61, 70, 75, 80, 90] @@ -99,14 +99,14 @@ class KernelsForDataType: ops.extend(alignment_ops) return ops - def default_operation(self, math_operation: cutlass.MathOperation): + def default_operation(self, math_operation: cutlass_cppgen.MathOperation): key = sorted(list(self.kernels_by_alignment.keys()))[0] kernels = self.kernels_by_alignment[key] if math_operation is not None: kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation] return kernels[0] - def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass.MathOperation): + def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation): """ Returns operations satisfying the alignment constraints @@ -117,7 +117,7 @@ class KernelsForDataType: :param alignment_C: alignment constraint of operations to return :type alignment_C: int :param math_operation: math operation to consider - :type math_operation: cutlass.MathOperation + :type math_operation: cutlass_cppgen.MathOperation :return: list of operations :rtype: list @@ -158,14 +158,14 @@ class KernelsForDataType: return operand_list.index(key) - def find_alignment(self, shape: tuple, layout: cutlass.LayoutType, operand=str) -> int: + def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int: """ Returns the most preferable alignment for a given shape and layout :param shape: extent of each dimension of the tensor :type shape: tuple :param layout: layout of the tensor - :type layout: cutlass.LayoutType + :type layout: cutlass_cppgen.LayoutType :param operand: descriptor of the operand in question :type operand: str @@ -175,11 +175,11 @@ class KernelsForDataType: operand_idx = self._operand_idx(operand) # Determine the leading dimension of the shape - if layout == cutlass.LayoutType.ColumnMajor: + if layout == cutlass_cppgen.LayoutType.ColumnMajor: ld = shape[-2] - elif layout == cutlass.LayoutType.RowMajor: + elif layout == cutlass_cppgen.LayoutType.RowMajor: ld = shape[-1] - elif layout == cutlass.LayoutType.TensorNHWC: + elif layout == cutlass_cppgen.LayoutType.TensorNHWC: ld = shape[-1] else: raise Exception(f"Unexpected or unsupported layout {layout}") @@ -204,12 +204,12 @@ class KernelsForDataType: for alignment in self.kernels_by_alignment.keys(): self.kernels_by_alignment[alignment].sort(key=key, reverse=True) - def supports_math_operation(self, math_operation: cutlass.MathOperation) -> bool: + def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool: """ Returns whether `math_operation` is supported by at least one operation. :param math_operation: math operation to consider - :type math_operation: cutlass.MathOperation + :type math_operation: cutlass_cppgen.MathOperation :return: whether math_operation is supported by at least one operation :rtype: bool @@ -262,7 +262,7 @@ class ArchOptions: # descriptions for the target CC generate_function_name = "GenerateSM" + str(kernel_cc) if not hasattr(cutlass_library.generator, generate_function_name): - cutlass.logger.warning(f"No generator found for architecture {kernel_cc}") + cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}") return generate_function = getattr(cutlass_library.generator, generate_function_name) @@ -270,16 +270,16 @@ class ArchOptions: # for the target CC args = [ "--kernels=all", - f"--log-level={logging.getLevelName(cutlass.logger.level)}" + f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}" ] manifest_args = cutlass_library.generator.define_parser().parse_args(args) manifest = cutlass_library.manifest.Manifest(manifest_args) - generate_function(manifest, cutlass._nvcc_version) + generate_function(manifest, cutlass_cppgen._nvcc_version) if operation_kind not in manifest.operations: # No kernels generated for this architecture, this could be because the CUDA # toolkit is insufficient to support operations in this CC - cutlass.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}") + cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}") return # Only one CC should be returned, given the setup above of calling only the generation scripts @@ -358,7 +358,7 @@ class ArchOptions: # Add FP8 A/B with FP32 C for type_comb in combinations_with_replacement(fp8_types, 2): - types.append(type_comb + (cutlass.DataType.f32,)) + types.append(type_comb + (cutlass_cppgen.DataType.f32,)) layouts = [ (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor), @@ -444,7 +444,7 @@ class ArchOptions: :param layout_comb: tuple of data types for (layout_A, layout_B) :type layout_comb: tuple[cutlass_library.LayoutType] :param math_operation: math operation to consider or None if any can be considered - :type math_operation: cutlass.MathOperation + :type math_operation: cutlass_cppgen.MathOperation :return: set of operation classes that support the provided data type and layout combination :rtype: set @@ -484,7 +484,7 @@ class ArchOptions: :param layout_b: layout of operand B :type layout_b: cutlass_library.LayoutType :param math_operation: math operation to consider - :type math_operation: cutlass.MathOperation + :type math_operation: cutlass_cppgen.MathOperation :return: set of operation classes that support the provided data type combination :rtype: set @@ -524,7 +524,7 @@ class ArchOptions: :param layout_b: layout of operand B :type layout_b: cutlass_library.LayoutType :param math_operation: math operation to consider - :type math_operation: cutlass.MathOperation + :type math_operation: cutlass_cppgen.MathOperation :return: container of kernels by alignment supported by the provided combination of parameters :rtype: KernelsForDataType diff --git a/python/cutlass/op/__init__.py b/python/cutlass/op/__init__.py index 5332556c..02869070 100644 --- a/python/cutlass/op/__init__.py +++ b/python/cutlass/op/__init__.py @@ -30,7 +30,7 @@ # ################################################################################################# -from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad -from cutlass.op.gemm import Gemm -from cutlass.op.gemm_grouped import GroupedGemm -from cutlass.op.op import OperationBase +from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.op.gemm_grouped import GroupedGemm +from cutlass_cppgen.op.op import OperationBase diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py index 3639d477..4f21d854 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass/op/conv.py @@ -47,7 +47,7 @@ .. code-block:: python # A, B, C, and D are torch/numpy/cupy tensor objects - plan = cutlass.op.Conv(A, B, C, D) + plan = cutlass_cppgen.op.Conv(A, B, C, D) plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) One can also use the interface by specifying data types of operands at construction @@ -57,11 +57,11 @@ .. code-block:: python # The following is shorthand for: - # cutlass.op.Conv2d(kind="fprop", + # cutlass_cppgen.op.Conv2d(kind="fprop", # element_A=torch.float32, element_B=torch.float32, # element_C=torch.float32, element_D=torch.float32, # element_accumulator=torch.float32) - plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32) A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') @@ -81,7 +81,7 @@ .. highlight:: python .. code-block:: python - plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) # Do other work... @@ -96,15 +96,15 @@ .. highlight:: python .. code-block:: python - plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) - plan.activation = cutlass.epilogue.relu + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + plan.activation = cutlass_cppgen.epilogue.relu Operations can also be run asynchronously: .. highlight:: python .. code-block:: python - plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) args = plan.run() # Do other work... @@ -114,7 +114,7 @@ from __future__ import annotations from typing import Optional -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") from cutlass_library import ( @@ -127,15 +127,15 @@ from cutlass_library import ( StrideSupport, ) -import cutlass -from cutlass import epilogue -from cutlass.backend import compiler -from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation -from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments -from cutlass.backend.library import TensorDescription, TileDescription -from cutlass.op.op import OperationBase -from cutlass.shape import Conv2DProblemSize, MatrixCoord -from cutlass.utils import check, datatypes +import cutlass_cppgen +from cutlass_cppgen import epilogue +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord +from cutlass_cppgen.utils import check, datatypes class Conv2d(OperationBase): @@ -155,11 +155,11 @@ class Conv2d(OperationBase): # Use F32 for A, B, C, D, and accumulation in fprop # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. - Conv2d(kind="fprop", element=cutlass.DataType.f32) + Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32) # Explicitly specify the data types to use for A, B, C, and D. - Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, - element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32) + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, + element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32) # Set the data types and elements from existing tensors. Note that one can use different tensors when # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must @@ -169,8 +169,8 @@ class Conv2d(OperationBase): # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit # those passed in via the generic ``element`` - Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, - element=cutlass.DataType.f32) + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + element=cutlass_cppgen.DataType.f32) The order of precedence for the setting of the data type for a given operand/output is as follows: 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor @@ -186,17 +186,17 @@ class Conv2d(OperationBase): :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B :param beta: scalar parameter beta from GEMM operation that scales operand C :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass.DataType + :type element: cutlass_cppgen.DataType :param element_A: data type to be used for operand A - :type element_A: cutlass.DataType + :type element_A: cutlass_cppgen.DataType :param element_B: data type to be used for operand B - :type element_B: cutlass.DataType + :type element_B: cutlass_cppgen.DataType :param element_C: data type to be used for operand C - :type element_C: cutlass.DataType + :type element_C: cutlass_cppgen.DataType :param element_D: data type to be used for operand D - :type element_D: cutlass.DataType + :type element_D: cutlass_cppgen.DataType :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_cppgen.DataType :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 :type cc: int :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 @@ -215,7 +215,7 @@ class Conv2d(OperationBase): if self.current_cc == 90: # The Conv2d kernel on Hopper (SM90) is currently unsupported # Revert to use SM80-tagged kernels - cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") self.specified_kernel_cc = 80 self._reset_options(80) @@ -250,7 +250,7 @@ class Conv2d(OperationBase): assert elt_to_set is not None # Currently we only support layout TensorNHWC - lay_to_set = cutlass.LayoutType.TensorNHWC + lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC elements.append(datatypes.library_type(elt_to_set)) layouts.append(lay_to_set) @@ -301,10 +301,10 @@ class Conv2d(OperationBase): self._layout_a, self._layout_b, self._math_operation ) - if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: - self.opclass = cutlass.OpcodeClass.TensorOp - elif cutlass.OpcodeClass.Simt in self.possible_op_classes: - self.opclass = cutlass.OpcodeClass.Simt + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt else: if self._math_operation is not None: math_op_str = f' and math operation {self._math_operation}' @@ -342,7 +342,7 @@ class Conv2d(OperationBase): Set the tile description :param td: tile description - :type td: cutlass.backend.TileDescription, or a dict with keys + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys { "threadblock_shape": [int, int, int], "warp_count": [int, int, int], @@ -359,7 +359,7 @@ class Conv2d(OperationBase): self._tile_description = datatypes.td_from_profiler_op(op) if "cluster_shape" in td.keys(): if td["cluster_shape"] != [1, 1, 1]: - cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") + cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") td["cluster_shape"] = [1, 1, 1] td = self._tile_description.clone_and_update(td) @@ -381,7 +381,7 @@ class Conv2d(OperationBase): - Is the kernel schedule being used supported on the architecture in question? :param td: tile description to validate - :type td: cutlass.backend.TileDescription + :type td: cutlass_cppgen.backend.TileDescription :return: tuple in which the first element is a bool indicating that the tile description is valid and the second element is a string providing an optional error message. :rtype: tuple @@ -445,9 +445,9 @@ class Conv2d(OperationBase): """ if self.conv_kind == ConvKind.Dgrad: if stride[0] != 1 or stride[1] != 1: - return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") + return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") - return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}") + return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}") # # Iterator Algorithm Related @@ -546,14 +546,14 @@ class Conv2d(OperationBase): self, tile_description: TileDescription = None, alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, iterator_algorithm: IteratorAlgorithm = None, - stride_support = None, swizzling_functor: cutlass.swizzle = None, - epilogue_functor=None) -> cutlass.backend.Conv2dOperation: + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation: """ - Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current + Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current kernel specification of the ``Conv2d`` object. :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass.backend.TileDescription + :type tile_description: cutlass_cppgen.backend.TileDescription :param alignment_A: alignment of operand A :type alignment_A: int :param alignment_B: alignment of operand B @@ -565,11 +565,11 @@ class Conv2d(OperationBase): :param stride_support: the stride support of dgrad :type stride_support: cutlass_library.library.StrideSupport :param swizzling_functor: the swizzling functor - :type swizzling_functor: cutlass.swizzle + :type swizzling_functor: cutlass_cppgen.swizzle :param epilogue_functor: the epilogue functor :return: operation that was constructed - :rtype: cutlass.backend.Conv2dOperation + :rtype: cutlass_cppgen.backend.Conv2dOperation """ # Get alignment alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) @@ -637,8 +637,8 @@ class Conv2d(OperationBase): def compile(self, tile_description: TileDescription = None, alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, iterator_algorithm: IteratorAlgorithm = None, - stride_support = None, swizzling_functor: cutlass.swizzle = None, - epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation: + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation: """ Emits and compiles the kernel currently specified. If ``tile_description`` and any of the ``alignment`` parameters are set, the kernel will be chosen using this @@ -646,7 +646,7 @@ class Conv2d(OperationBase): will be used. ::param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass.backend.TileDescription + :type tile_description: cutlass_cppgen.backend.TileDescription :param alignment_A: alignment of operand A :type alignment_A: int :param alignment_B: alignment of operand B @@ -658,11 +658,11 @@ class Conv2d(OperationBase): :param stride_support: the stride support of dgrad :type stride_support: cutlass_library.library.StrideSupport :param swizzling_functor: the swizzling functor - :type swizzling_functor: cutlass.swizzle + :type swizzling_functor: cutlass_cppgen.swizzle :param epilogue_functor: the epilogue functor :return: operation that was compiled - :rtype: cutlass.backend.Conv2dOperation + :rtype: cutlass_cppgen.backend.Conv2dOperation """ self.operation = self.construct( @@ -770,7 +770,7 @@ class Conv2d(OperationBase): :type stream: :class:`cuda.cuda.CUstream` :return: arguments passed in to the kernel - :rtype: cutlass.backend.Conv2dArguments + :rtype: cutlass_cppgen.backend.Conv2dArguments """ if not stream: stream = cuda.CUstream(0) diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index 786c565b..fddd0c09 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -47,7 +47,7 @@ .. code-block:: python # A, B, C, and D are torch/numpy/cupy tensor objects - plan = cutlass.op.Gemm(A, B, C, D) + plan = cutlass_cppgen.op.Gemm(A, B, C, D) plan.run() @@ -58,11 +58,11 @@ .. code-block:: python # The following is shorthand for: - # cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32, + # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32, # element_C=torch.float32, element_D=torch.float32, # element_accumulator=torch.float32, - # layout=cutlass.LayoutType.RowMajor) - plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + # layout=cutlass_cppgen.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) A0 = torch.rand((128, 256), device='cuda') B0 = torch.rand((256, 64), device='cuda') @@ -82,7 +82,7 @@ .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) plan.compile() # Do other work... @@ -98,15 +98,15 @@ .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) - plan.activation = cutlass.epilogue.relu + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu Operations can also be run asynchronously: .. highlight:: python .. code-block:: python - plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) args = plan.run() # Do other work... @@ -117,7 +117,7 @@ from __future__ import annotations from typing import Optional from math import prod -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") from cutlass_library import ( DataType, @@ -125,15 +125,15 @@ from cutlass_library import ( GemmUniversalMode, ) -import cutlass -from cutlass import epilogue, swizzle -from cutlass.backend import compiler -from cutlass.backend.evt import EpilogueFunctorVisitor -from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal -from cutlass.backend.library import TensorDescription, TileDescription -from cutlass.op.op import OperationBase -from cutlass.shape import GemmCoord -from cutlass.utils import check, datatypes +import cutlass_cppgen +from cutlass_cppgen import epilogue, swizzle +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes class Gemm(OperationBase): @@ -154,11 +154,11 @@ class Gemm(OperationBase): # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts # for operands to the same values. - Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. - Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32, - element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32, + element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) # Set the data types and elements from existing tensors. Note that one can use different tensors when # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must @@ -168,13 +168,13 @@ class Gemm(OperationBase): # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is # the same as that for D, at present) - Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor, - layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor) + Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor, + layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor) # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types # and layouts will inherit those passed in via the generic ``element`` and ``layout`` - Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor, - element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor) + Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor, + element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) The order of precedence for the setting of the data type and layout for a given operand/output is as follows: 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor @@ -192,27 +192,27 @@ class Gemm(OperationBase): :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B :param beta: scalar parameter beta from GEMM operation that scales operand C :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_cppgen.DataType :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass.DataType + :type element: cutlass_cppgen.DataType :param layout: generic layout type to be used for operands A, B, C, and D - :type layout: cutlass.LayoutType + :type layout: cutlass_cppgen.LayoutType :param element_A: data type to be used for operand A - :type element_A: cutlass.DataType + :type element_A: cutlass_cppgen.DataType :param element_B: data type to be used for operand B - :type element_B: cutlass.DataType + :type element_B: cutlass_cppgen.DataType :param element_C: data type to be used for operand C - :type element_C: cutlass.DataType + :type element_C: cutlass_cppgen.DataType :param element_D: data type to be used for operand D - :type element_D: cutlass.DataType + :type element_D: cutlass_cppgen.DataType :param layout_A: layout of operand A - :type layout_A: cutlass.LayoutType + :type layout_A: cutlass_cppgen.LayoutType :param layout_B: layout of operand B - :type layout_B: cutlass.LayoutType + :type layout_B: cutlass_cppgen.LayoutType :param layout_C: layout of operand C - :type layout_C: cutlass.LayoutType + :type layout_C: cutlass_cppgen.LayoutType :param layout_D: layout of operand D - :type layout_D: cutlass.LayoutType + :type layout_D: cutlass_cppgen.LayoutType """ def __init__( @@ -278,7 +278,7 @@ class Gemm(OperationBase): self._reset_operations() - self._swizzling_functor = cutlass.swizzle.IdentitySwizzle1 + self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1 def _reset_operations(self, reset_epilogue: bool = True): # Set the default op class @@ -289,10 +289,10 @@ class Gemm(OperationBase): self._element_a, self._element_b, self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) - if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: - self.opclass = cutlass.OpcodeClass.TensorOp - elif cutlass.OpcodeClass.Simt in self.possible_op_classes: - self.opclass = cutlass.OpcodeClass.Simt + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt else: if self._math_operation is not None: math_op_str = f' and math operation {self._math_operation}' @@ -303,7 +303,7 @@ class Gemm(OperationBase): f'combination {datatype_comb}x{layout_comb}{math_op_str}') if reset_epilogue: - self._reset_epilogue_functor_activation(cutlass.epilogue.identity) + self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity) @property def swizzling_functor(self): @@ -319,8 +319,8 @@ class Gemm(OperationBase): """ Sets the swizzling functor to the type specified by `swizzling_functor` """ - if swizzling_functor == cutlass.swizzle.ThreadblockSwizzleStreamK: - if self.op_class == cutlass.OpcodeClass.Simt: + if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK: + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') if self.current_cc == 90: @@ -345,7 +345,7 @@ class Gemm(OperationBase): Set the tile description :param td: tile description - :type td: cutlass.backend.TileDescription, or a dict with keys + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys { "threadblock_shape": [int, int, int], "warp_count": [int, int, int], @@ -380,7 +380,7 @@ class Gemm(OperationBase): - Is the kernel schedule being used supported on the architecture in question? :param td: tile description to validate - :type td: cutlass.backend.TileDescription + :type td: cutlass_cppgen.backend.TileDescription :return: tuple in which the first element is a bool indicating that the tile description is valid and the second element is a string providing an optional error message. :rtype: tuple @@ -412,11 +412,11 @@ class Gemm(OperationBase): self, tile_description: TileDescription = None, alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: """ - Constructs a ``cutlass.backend.GemmUniversalOperation`` based on the input parameters and current + Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current kernel specification of the ``Gemm`` object. :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass.backend.TileDescription + :type tile_description: cutlass_cppgen.backend.TileDescription :param alignment_A: alignment of operand A :type alignment_A: int :param alignment_B: alignment of operand B @@ -425,7 +425,7 @@ class Gemm(OperationBase): :type alignment_C: int :return: operation that was constructed - :rtype: cutlass.backend.GemmOperationUniversal + :rtype: cutlass_cppgen.backend.GemmOperationUniversal """ alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) @@ -471,7 +471,7 @@ class Gemm(OperationBase): def compile(self, tile_description: TileDescription = None, alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, - print_module: bool = False) -> cutlass.backend.GemmOperationUniversal: + print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal: """ Emits and compiles the kernel currently specified. If ``tile_description`` and any of the ``alignment`` parameters are set, the kernel will be chosen using this @@ -479,7 +479,7 @@ class Gemm(OperationBase): will be used. :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass.backend.TileDescription + :type tile_description: cutlass_cppgen.backend.TileDescription :param alignment_A: alignment of operand A :type alignment_A: int :param alignment_B: alignment of operand B @@ -490,7 +490,7 @@ class Gemm(OperationBase): :type print_module: bool :return: operation that was compiled - :rtype: cutlass.backend.GemmOperationUniversal + :rtype: cutlass_cppgen.backend.GemmOperationUniversal """ self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) @@ -566,7 +566,7 @@ class Gemm(OperationBase): :param D: tensor D :type D: numpy/cupy/torch array/tensor object - :return: tuple containing the problem size (cutlass.shape.GemmCoord), the GEMM mode (cutlass.GemmUniversalMode), and the batch count (int) + :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int) :rtype: tuple """ M, K = A.shape[-2:] @@ -582,9 +582,9 @@ class Gemm(OperationBase): # and C are row major. A similar operation can be performed if only B has a nonzero # batch dimension if batch_count > 1: - A_row = self._layout_a == cutlass.LayoutType.RowMajor - B_row = self._layout_b == cutlass.LayoutType.RowMajor - C_row = self._layout_c == cutlass.LayoutType.RowMajor + A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor + B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor + C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor # Consider a Tensor to be batched if its rank is > 2 and # the product of the modes beyond rank 2 equals our pre-determined batch size. @@ -652,7 +652,7 @@ class Gemm(OperationBase): :type stream: :class:`cuda.cuda.CUstream` :return: arguments passed in to the kernel - :rtype: cutlass.backend.GemmArguments + :rtype: cutlass_cppgen.backend.GemmArguments """ if not stream: stream = cuda.CUstream(0) diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index da2fc8b9..594106f2 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -47,27 +47,27 @@ .. code-block:: python # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects - plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) """ from __future__ import annotations from typing import Optional from cutlass_library import DataTypeSize -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") -from cutlass.backend.gemm_operation import ( +from cutlass_cppgen.backend.gemm_operation import ( GemmGroupedArguments, GemmOperationGrouped, ) -from cutlass.backend.library import ( +from cutlass_cppgen.backend.library import ( SchedulerMode, TensorDescription, TileDescription, ) -from cutlass.op.gemm import Gemm -from cutlass.shape import GemmCoord -from cutlass.utils import check, datatypes +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes class GroupedGemm(Gemm): @@ -90,27 +90,27 @@ class GroupedGemm(Gemm): :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B :param beta: scalar parameter beta from GEMM operation that scales operand C :param element_accumulator: data type to be used in accumulation of the product of operands A and B - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_cppgen.DataType :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type - :type element: cutlass.DataType + :type element: cutlass_cppgen.DataType :param layout: generic layout type to be used for operands A, B, C, and D - :type layout: cutlass.LayoutType + :type layout: cutlass_cppgen.LayoutType :param element_A: data type to be used for operand A - :type element_A: cutlass.DataType + :type element_A: cutlass_cppgen.DataType :param element_B: data type to be used for operand B - :type element_B: cutlass.DataType + :type element_B: cutlass_cppgen.DataType :param element_C: data type to be used for operand C - :type element_C: cutlass.DataType + :type element_C: cutlass_cppgen.DataType :param element_D: data type to be used for operand D - :type element_D: cutlass.DataType + :type element_D: cutlass_cppgen.DataType :type layout_A: layout of operand A - :param layout_A: cutlass.LayoutType + :param layout_A: cutlass_cppgen.LayoutType :type layout_B: layout of operand B - :param layout_B: cutlass.LayoutType + :param layout_B: cutlass_cppgen.LayoutType :type layout_C: layout of operand C - :param layout_C: cutlass.LayoutType + :param layout_C: cutlass_cppgen.LayoutType :type layout_D: layout of operand D - :param layout_D: cutlass.LayoutType + :param layout_D: cutlass_cppgen.LayoutType """ def __init__( @@ -151,11 +151,11 @@ class GroupedGemm(Gemm): alignment_B: int = None, alignment_C: int = None) -> GemmOperationGrouped: """ - Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current + Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current kernel specification of the ``Gemm`` object. :param tile_description: tile description specifying shapes and operand types to use in the kernel - :type tile_description: cutlass.backend.TileDescription + :type tile_description: cutlass_cppgen.backend.TileDescription :param alignment_A: alignment of operand A :type alignment_A: int :param alignment_B: alignment of operand B @@ -164,7 +164,7 @@ class GroupedGemm(Gemm): :type alignment_C: int :return: operation that was constructed - :rtype: cutlass.backend.GemmOperationGrouped + :rtype: cutlass_cppgen.backend.GemmOperationGrouped """ alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) @@ -225,7 +225,7 @@ class GroupedGemm(Gemm): :type stream: :class:`cuda.cuda.CUstream` :return: arguments passed in to the kernel - :rtype: cutlass.backend.GemmGroupedArguments + :rtype: cutlass_cppgen.backend.GemmGroupedArguments """ if not stream: stream = cuda.CUstream(0) diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py index 444df8b9..88ccd26e 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass/op/op.py @@ -44,14 +44,14 @@ from cutlass_library import ( SharedMemPerCC ) -import cutlass -from cutlass import get_option_registry -from cutlass.backend.evt import EpilogueFunctorVisitor -from cutlass.backend.utils.device import device_cc -from cutlass.epilogue import get_activations, get_activation_epilogue, identity -from cutlass.library_defaults import KernelsForDataType, _generator_ccs -from cutlass.swizzle import get_swizzling_functors -from cutlass.utils import datatypes, check +import cutlass_cppgen +from cutlass_cppgen import get_option_registry +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity +from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs +from cutlass_cppgen.swizzle import get_swizzling_functors +from cutlass_cppgen.utils import datatypes, check class OperationBase: @@ -205,19 +205,19 @@ class OperationBase: return tensor @property - def opclass(self) -> cutlass.OpcodeClass: + def opclass(self) -> cutlass_cppgen.OpcodeClass: """ Returns the opcode class currently in use :return: opcode class currently in use - :rtype: cutlass.OpcodeClass + :rtype: cutlass_cppgen.OpcodeClass """ return self.op_class @opclass.setter - def opclass(self, oc: cutlass.OpcodeClass): + def opclass(self, oc: cutlass_cppgen.OpcodeClass): if isinstance(oc, str): - oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc) + oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc) if oc in self.possible_op_classes: self.op_class = oc else: @@ -236,25 +236,25 @@ class OperationBase: self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) @property - def math_operation(self) -> cutlass.MathOperation: + def math_operation(self) -> cutlass_cppgen.MathOperation: """ Returns the math operation currently in use :return: math operation currently in use - :rtype: cutlass.MathOperation + :rtype: cutlass_cppgen.MathOperation """ return self._math_operation @math_operation.setter - def math_operation(self, mo: cutlass.MathOperation): + def math_operation(self, mo: cutlass_cppgen.MathOperation): if isinstance(mo, str): - mo = datatypes.getattr_enum(cutlass.MathOperation, mo) + mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) if not self.specified_kernel_cc: if self.current_cc == 90: # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. - cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") self._reset_options(80) self._reset_operations(reset_epilogue=False) elif self.current_cc == 90: @@ -266,7 +266,7 @@ class OperationBase: self._reset_operations() def _elements_per_access(self): - if self.op_class == cutlass.OpcodeClass.Simt: + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: return 1 elif self._element_c != DataType.void: return 128 // DataTypeSize[self._element_c] @@ -286,7 +286,7 @@ class OperationBase: if self.current_cc == 90 and activation != identity: # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. - cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") if self._element_c != self._element_d: raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") self._reset_options(80) @@ -361,7 +361,7 @@ class OperationBase: """ if isinstance(act, tuple): if isinstance(act[0], str): - act_fn = getattr(cutlass.backend.epilogue, act[0]) + act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0]) else: act_fn = act[0] self._reset_epilogue_functor_activation(act_fn) @@ -369,7 +369,7 @@ class OperationBase: self._activation = act[0] else: if isinstance(act, str): - act = getattr(cutlass.backend.epilogue, act) + act = getattr(cutlass_cppgen.backend.epilogue, act) self._reset_epilogue_functor_activation(act) self._activation = act @@ -401,8 +401,8 @@ class OperationBase: td = datatypes.td_from_profiler_op(operation) # Filter invalid epilogue schedules if td.epilogue_schedule not in [ - cutlass.EpilogueScheduleType.TmaWarpSpecialized, - cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative]: + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: continue epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) @@ -427,4 +427,4 @@ class OperationBase: Steps that must be taken before caling `plan.run()` """ # Initialize the memory pool if, if not already done - cutlass.get_memory_pool() + cutlass_cppgen.get_memory_pool() diff --git a/python/cutlass/shape.py b/python/cutlass/shape.py index 0987899a..a718f9bb 100644 --- a/python/cutlass/shape.py +++ b/python/cutlass/shape.py @@ -39,7 +39,7 @@ from cutlass_library import ( ConvKind, LayoutType ) -from cutlass.backend.c_types import ( +from cutlass_cppgen.backend.c_types import ( Conv2DProblemSize_, GemmCoord_, GemmCoordBatched_ diff --git a/python/cutlass/utils/__init__.py b/python/cutlass/utils/__init__.py index 21658035..75d8416a 100644 --- a/python/cutlass/utils/__init__.py +++ b/python/cutlass/utils/__init__.py @@ -30,7 +30,7 @@ # ################################################################################################# -from cutlass.utils.check import ( +from cutlass_cppgen.utils.check import ( alignment_or_default, calculate_smem_usage, calculate_smem_usage_per_stage, diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py index 7cc004ec..ff76a42b 100644 --- a/python/cutlass/utils/check.py +++ b/python/cutlass/utils/check.py @@ -38,8 +38,8 @@ import ctypes from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC -import cutlass -from cutlass.backend.library import TileDescription +import cutlass_cppgen +from cutlass_cppgen.backend.library import TileDescription def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: @@ -82,8 +82,8 @@ def valid_stage_count( cc: int, kernel_cc: int, td: TileDescription, - element_C: cutlass.DataType = None, - element_D: cutlass.DataType = None, + element_C: cutlass_cppgen.DataType = None, + element_D: cutlass_cppgen.DataType = None, verbose: bool = True) -> tuple: """ Checks whether a device with `cc` supports the number of stages within `tile_description`, both @@ -96,9 +96,9 @@ def valid_stage_count( :param td: tile description to check :type td: TileDescription :param element_C: data type of operand C - :type element_C: cutlass.DataType + :type element_C: cutlass_cppgen.DataType :param element_D: data type of operand D - :type element_D: cutlass.DataType + :type element_D: cutlass_cppgen.DataType :param verbose: whether to log warnings :type verbose: bool @@ -112,7 +112,7 @@ def valid_stage_count( # determines the stage count to use. Thus, all settings are valid in these scenarios. return (True, "") elif verbose: - cutlass.logger.warning( + cutlass_cppgen.logger.warning( "Setting an explicit stage count for SM90 kernels currently may " "result in compilation errors if the combination of tile shape, " "stage count, and shared memory requirement of the epilogue exceeds " @@ -188,9 +188,9 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: def valid_schedule( cc: int, - kernel_schedule: cutlass.KernelScheduleType, - epilogue_schedule: cutlass.EpilogueScheduleType, - tile_scheduler: cutlass.TileSchedulerType) -> tuple: + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple: """ Checks that the kernel and epilogue schedules passed in are a valid combination for a device of compute capability ``cc``. @@ -198,19 +198,19 @@ def valid_schedule( :param cc: compute capability of device in question :type cc: int :param kernel_schedule: kernel schedule type - :type kernel_schedule: cutlass.KernelScheduleType + :type kernel_schedule: cutlass_cppgen.KernelScheduleType :param epilogue_schedule: epilogue schedule type - :type epilogue_schedule: cutlass.EpilogueScheduleType + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType :param tile_scheduler: tile scheduler type - :type tile_scheduler: cutlass.TileSchedulerType + :type tile_scheduler: cutlass_cppgen.TileSchedulerType :return: tuple with the first element indicating whether the provided schedules are valid for the provided device and the second element being an error message :rtype: tuple """ - kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto) - epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto) - tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default) + kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) + epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) + tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default): return (False, "Non-default schedules are only supported on SM90 and beyond") @@ -218,9 +218,9 @@ def valid_schedule( return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") if not tile_scheduler_default: - cooperative_kernels = [cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, - cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative] - if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] + if (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") return (True, "") diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index 75beda65..c03a834d 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -34,13 +34,13 @@ Utility functions for converting between frontend datatypes and CUTLASS datatypes """ -import cutlass +import cutlass_cppgen from cutlass_library import ( DataTypeSize, MathOperation, MathInstruction ) -from cutlass.backend.library import ( +from cutlass_cppgen.backend.library import ( TileDescription, ) @@ -62,11 +62,11 @@ def is_numpy_available(): numpy_available = True _library_to_numpy_dict = { - cutlass.DataType.f16: np.float16, - cutlass.DataType.f32: np.float32, - cutlass.DataType.f64: np.float64, - cutlass.DataType.s8: np.int8, - cutlass.DataType.s32: np.int32, + cutlass_cppgen.DataType.f16: np.float16, + cutlass_cppgen.DataType.f32: np.float32, + cutlass_cppgen.DataType.f64: np.float64, + cutlass_cppgen.DataType.s8: np.int8, + cutlass_cppgen.DataType.s32: np.int32, } except ImportError: numpy_available = False @@ -81,19 +81,19 @@ def is_numpy_tensor(inp) -> bool: return False -def numpy_library_type(inp) -> cutlass.DataType: +def numpy_library_type(inp) -> cutlass_cppgen.DataType: if is_numpy_available(): import numpy as np if inp == np.float16: - return cutlass.DataType.f16 + return cutlass_cppgen.DataType.f16 elif inp == np.float32: - return cutlass.DataType.f32 + return cutlass_cppgen.DataType.f32 elif inp == np.float64: - return cutlass.DataType.f64 + return cutlass_cppgen.DataType.f64 elif inp == np.int8: - return cutlass.DataType.s8 + return cutlass_cppgen.DataType.s8 elif inp == np.int32: - return cutlass.DataType.s32 + return cutlass_cppgen.DataType.s32 return None @@ -109,11 +109,11 @@ def is_cupy_available(): cupy_available = True _library_to_cupy_dict = { - cutlass.DataType.f16: cp.float16, - cutlass.DataType.f32: cp.float32, - cutlass.DataType.f64: cp.float64, - cutlass.DataType.s8: cp.int8, - cutlass.DataType.s32: cp.int32, + cutlass_cppgen.DataType.f16: cp.float16, + cutlass_cppgen.DataType.f32: cp.float32, + cutlass_cppgen.DataType.f64: cp.float64, + cutlass_cppgen.DataType.s8: cp.int8, + cutlass_cppgen.DataType.s32: cp.int32, } except ImportError: cupy_available = False @@ -128,15 +128,15 @@ def is_cupy_tensor(inp) -> bool: return False -def cupy_library_type(inp) -> cutlass.DataType: +def cupy_library_type(inp) -> cutlass_cppgen.DataType: if is_cupy_available(): import cupy as cp if inp == cp.float16: - return cutlass.DataType.f16 + return cutlass_cppgen.DataType.f16 elif inp == cp.float32: - return cutlass.DataType.f32 + return cutlass_cppgen.DataType.f32 elif inp == cp.float64: - return cutlass.DataType.f64 + return cutlass_cppgen.DataType.f64 return None @@ -152,29 +152,29 @@ def is_torch_available(): torch_available = True _torch_to_library_dict = { - torch.half: cutlass.DataType.f16, - torch.float16: cutlass.DataType.f16, - torch.bfloat16: cutlass.DataType.bf16, - torch.float: cutlass.DataType.f32, - torch.float32: cutlass.DataType.f32, - torch.double: cutlass.DataType.f64, - torch.float64: cutlass.DataType.f64, - torch.int8: cutlass.DataType.s8, - torch.int32: cutlass.DataType.s32, - torch.uint8: cutlass.DataType.u8, + torch.half: cutlass_cppgen.DataType.f16, + torch.float16: cutlass_cppgen.DataType.f16, + torch.bfloat16: cutlass_cppgen.DataType.bf16, + torch.float: cutlass_cppgen.DataType.f32, + torch.float32: cutlass_cppgen.DataType.f32, + torch.double: cutlass_cppgen.DataType.f64, + torch.float64: cutlass_cppgen.DataType.f64, + torch.int8: cutlass_cppgen.DataType.s8, + torch.int32: cutlass_cppgen.DataType.s32, + torch.uint8: cutlass_cppgen.DataType.u8, } _library_to_torch_dict = { - cutlass.DataType.f16: torch.half, - cutlass.DataType.f16: torch.float16, - cutlass.DataType.bf16: torch.bfloat16, - cutlass.DataType.f32: torch.float, - cutlass.DataType.f32: torch.float32, - cutlass.DataType.f64: torch.double, - cutlass.DataType.f64: torch.float64, - cutlass.DataType.s8: torch.int8, - cutlass.DataType.s32: torch.int32, - cutlass.DataType.u8: torch.uint8, + cutlass_cppgen.DataType.f16: torch.half, + cutlass_cppgen.DataType.f16: torch.float16, + cutlass_cppgen.DataType.bf16: torch.bfloat16, + cutlass_cppgen.DataType.f32: torch.float, + cutlass_cppgen.DataType.f32: torch.float32, + cutlass_cppgen.DataType.f64: torch.double, + cutlass_cppgen.DataType.f64: torch.float64, + cutlass_cppgen.DataType.s8: torch.int8, + cutlass_cppgen.DataType.s32: torch.int32, + cutlass_cppgen.DataType.u8: torch.uint8, } def possibly_add_type(torch_type_name, cutlass_type): @@ -184,8 +184,8 @@ def is_torch_available(): _torch_to_library_dict[torch_type] = cutlass_type _library_to_torch_dict[cutlass_type] = torch_type - possibly_add_type("float8_e4m3fn", cutlass.DataType.e4m3) - possibly_add_type("float8_e5m2", cutlass.DataType.e5m2) + possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3) + possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2) except ImportError: torch_available = False @@ -201,7 +201,7 @@ def is_torch_tensor(inp) -> bool: return False -def torch_library_type(inp) -> cutlass.DataType: +def torch_library_type(inp) -> cutlass_cppgen.DataType: return _torch_to_library_dict.get(inp, None) @@ -222,17 +222,17 @@ def is_bfloat16_available(): return bfloat16_available -def bfloat16_library_type(inp) -> cutlass.DataType: +def bfloat16_library_type(inp) -> cutlass_cppgen.DataType: if is_bfloat16_available(): import bfloat16 if inp == bfloat16.bfloat16: - return cutlass.DataType.bf16 + return cutlass_cppgen.DataType.bf16 def bfloat16_type(inp): if is_bfloat16_available(): import bfloat16 - if inp == cutlass.DataType.bf16: + if inp == cutlass_cppgen.DataType.bf16: return bfloat16.bfloat16 @@ -256,15 +256,15 @@ def library_type(inp): def _tensor_from_numpy(np_tensor): dtype = library_type(np_tensor.dtype) if np_tensor.flags.c_contiguous: - layout = cutlass.LayoutType.RowMajor + layout = cutlass_cppgen.LayoutType.RowMajor elif np_tensor.flags.f_contiguous: - layout = cutlass.LayoutType.ColumnMajor + layout = cutlass_cppgen.LayoutType.ColumnMajor return (dtype, layout) def _tensor_from_torch(pt_tensor): dtype = library_type(pt_tensor.dtype) - return (dtype, cutlass.LayoutType.RowMajor) + return (dtype, cutlass_cppgen.LayoutType.RowMajor) def get_datatype_and_layout(tensor): @@ -273,7 +273,7 @@ def get_datatype_and_layout(tensor): elif is_torch_tensor(tensor): return _tensor_from_torch(tensor) elif isinstance(tensor, float) or isinstance(tensor, int): - return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor) + return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor) else: raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") @@ -303,10 +303,10 @@ def backend_math_operation(math_op: MathOperation): return _math_operation_value_map[math_op.value] -def construct_backend_td(td: cutlass.TileDescription, - kernel_schedule: cutlass.KernelScheduleType, - epilogue_schedule: cutlass.EpilogueScheduleType, - tile_scheduler: cutlass.TileSchedulerType) -> TileDescription: +def construct_backend_td(td: cutlass_cppgen.TileDescription, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription: mi = td.math_instruction backend_mi = MathInstruction( mi.instruction_shape, @@ -328,7 +328,7 @@ def td_from_profiler_op(op) -> TileDescription: :param op: profiler Operation :returns: backend TileDescription - :rtype: cutlass.backend.TileDescription + :rtype: cutlass_cppgen.backend.TileDescription """ kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None @@ -341,10 +341,10 @@ def td_from_profiler_td(td: TileDescription) -> TileDescription: Converts the profiler's TileDescription into the backend TileDescription :param td: profiler TileDescription - :type td: cutlass.TileDescription + :type td: cutlass_cppgen.TileDescription :returns: backend TileDescription - :rtype: cutlass.backend.TileDescription + :rtype: cutlass_cppgen.backend.TileDescription """ return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) diff --git a/python/cutlass/utils/profiler.py b/python/cutlass/utils/profiler.py index 5733f3ba..f53b1567 100644 --- a/python/cutlass/utils/profiler.py +++ b/python/cutlass/utils/profiler.py @@ -37,16 +37,16 @@ Profiler based on the cuda events import re import subprocess -from cutlass.utils.lazy_import import lazy_import +from cutlass_cppgen.utils.lazy_import import lazy_import cuda = lazy_import("cuda.cuda") cudart = lazy_import("cuda.cudart") import numpy as np -from cutlass import CUTLASS_PATH -from cutlass.backend.library import DataTypeSize -from cutlass.op.op import OperationBase -from cutlass.shape import GemmCoord -from cutlass.utils.datatypes import is_numpy_tensor +from cutlass_cppgen import CUTLASS_PATH +from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor class GpuTimer: diff --git a/python/cutlass_library/__init__.py b/python/cutlass_library/__init__.py index 3d95edb5..534eef47 100644 --- a/python/cutlass_library/__init__.py +++ b/python/cutlass_library/__init__.py @@ -49,7 +49,6 @@ from . import rank_2k_operation from . import rank_k_operation from . import symm_operation from . import trmm_operation - # Make enum types from library.py accessible via cutlass_library.* from .library import * diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 4ae2d8ed..3121b7b0 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -279,7 +279,7 @@ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode ): # For functional testing, we prefer to run reference computing on device if any - reference_device_archs = ["100a"] + reference_device_archs = ["100a", "103a"] run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False profiler_flags_for_verification = "device" if run_reference_on_device else "host" @@ -287,7 +287,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode # TODO: randomize beta values for wider coverage beta_values = [0.5] - is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "120a", "120f"]) + is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"]) is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch @@ -306,6 +306,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'bf16gemm_f32_f32_f32_f32_f32', ] + exclude_archs = arch not in ("103a") + if exclude_archs: + sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8') + sm100_mma_data_type_runtime_dtype = [ 'gemm.*f4_f4_f32_f32_f32', 'gemm.*f6_f6_f32_f32_f32', @@ -344,6 +348,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', ] + sm103_block_scaled_data_type = [ + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + ] + block_scaled_cluster_size = [ '4x4x1', '2x1x1', '0x0x1' # dynamic cluster @@ -354,6 +363,9 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + if arch in ["100a", "100f"]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ @@ -361,15 +373,23 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode f"({sm100_mma_filter_regex_2sm_runtime})|" \ f"({block_scaled_filter_regex_1sm})|" \ f"({block_scaled_filter_regex_2sm})" - elif arch in ["101a", "101f", - ]: + elif arch in ["101a", "101f", "110a", "110f"]: kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ f"({sm100_mma_filter_regex_1sm_runtime})|" \ f"({sm100_mma_filter_regex_2sm_runtime})|" \ f"({block_scaled_filter_regex_1sm})|" \ f"({block_scaled_filter_regex_2sm})" - elif arch in ["120a", "120f"]: + elif arch in ["103a"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})|" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + elif arch in ["120a", "120f", "121a", "121f"]: # blockscaled sm120_mma kernels blockscaled_sm120_mma_kernel_cta_tiles = [ @@ -384,7 +404,7 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode kernel_filter = f"({filter_regex_blockscaled_sm120_mma})" else: - error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm120a, sm120f" + error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f" raise Exception(error_message) elif mode == "functional_L1": @@ -403,16 +423,27 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', ] - block_scaled_cluster_size = ['4x4x1', '2x1x1', '0x0x1'] + sm103_block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + ] + + block_scaled_cluster_size = ['0x0x1'] block_scaled_layouts = ['tnt'] # regex list must be in kernel procedural name order block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ f"({sm100_mma_filter_regex_2sm})|" \ f"({block_scaled_filter_regex_1sm})|" \ - f"({block_scaled_filter_regex_2sm})" + f"({block_scaled_filter_regex_2sm})" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times sm120_mma_kernel_cta_tiles = [ # h1688, s1688, i16832, i8816 @@ -449,7 +480,10 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode problem_waves = [0.5, 1.25, 2.5] - kernel_filter = f"({filter_regex_sm100_mma})|({filter_regex_sm120_mma})" + if arch in ["120a", "120f", "121a", "121f"]: + kernel_filter = f"({filter_regex_sm120_mma})" + else: + kernel_filter = f"({filter_regex_sm100_mma})" else: raise ValueError() diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index ce0b16f9..1d247625 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -341,7 +341,7 @@ class GemmOperation: Get the tile shape passed to the collective builder. On Blackwell, this is different than the operation.tile_description.tile_shape. """ - is_sm100_kernel = (self.arch == 100) + is_sm100_kernel = (self.arch == 100 or self.arch == 103) if not is_sm100_kernel: return self.tile_description.tile_shape @@ -995,6 +995,24 @@ ${compile_guard_end} epi_tile_mn = "cute::Shape" if not is_no_smem_epilogue: epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] + if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 0cdb2155..19cef8c7 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -90,10 +90,12 @@ try: raise ImportError("Disabling attempt to import cutlass_library") from cutlass_library.library import * from cutlass_library.manifest import * + from cutlass_library.heuristics import * from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist except ImportError: from library import * from manifest import * + from heuristics import * from emit_kernel_listing import emit_gemm_kernel_testlist ################################################################################################### @@ -112,6 +114,10 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): cuda_version.append(x) return cuda_version >= [major, minor, patch] +# From cuda 13.0, Thor SM is renumbered from 101 to 110 +def ThorSMRenumbering(cuda_version): + return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101 + ################################################################################################### ################################################################################################### @@ -6768,9 +6774,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): }, ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + math_instructions_1sm = [ # tf32 -> f32 MathInstruction( @@ -6887,7 +6895,8 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm grouped = is_grouped(gemm_kind) @@ -7202,9 +7211,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + epi_type = DataType.f32 grouped = is_grouped(gemm_kind) @@ -7889,9 +7900,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): TileSchedulerType.Default, TileSchedulerType.StreamK ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + epi_type = DataType.f32 math_instructions_1sm = [] @@ -8092,6 +8105,8 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + grouped = is_grouped(gemm_kind) + layouts = [ [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], @@ -8120,14 +8135,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud def tile_schedulers(sfdtype): # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, # the epilogue is the traditional linear combination, for which we already have tests with stream-K. - if sfdtype["type"] == DataType.void: + if sfdtype["type"] == DataType.void or grouped: return [TileSchedulerType.Default] else: return [TileSchedulerType.Default, TileSchedulerType.StreamK] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + epi_type = DataType.f32 math_instructions_1sm = [] @@ -8209,6 +8226,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8246,7 +8273,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud for data_type in data_types: CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]] + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) cluster_shapes_2sm = [ @@ -8288,6 +8315,16 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8346,7 +8383,11 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: continue - if math_inst.instruction_shape[0] == 128: + if grouped: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + elif math_inst.instruction_shape[0] == 128: CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) @@ -8396,9 +8437,11 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio else: return [TileSchedulerType.Default, TileSchedulerType.StreamK] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + epi_type = DataType.f32 math_instructions_1sm = [] @@ -8496,6 +8539,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8625,6 +8678,16 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio "sf_type" : math_inst.element_scale_factor, "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, { "a_type" : math_inst.element_a, "b_type" : math_inst.element_b, @@ -8715,6 +8778,230 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio +def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 96], + ] + + instruction_sizes_2sm = [ + [256, 128, 96], + ] + + ab_types = [ + DataType.e2m1, + ] + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + min_cc = 103 + max_cc = 103 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) # UE8M0 scale factor + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) # UE8M0 scale factor + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1] + , DynamicClusterShape + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 768], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] + fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] + fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized1Sm] + # For FP4 inputs + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch + ,fp4_schedule_enable_prefetch + ] + , gemm_kind=gemm_kind + ) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + fp4_schedule = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] + fp4_schedule_disable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] + fp4_schedule_enable_prefetch = [KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] + # For FP4 inputs + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule, fp4_schedule_disable_prefetch + ,fp4_schedule_enable_prefetch + ] + , gemm_kind=gemm_kind + ) def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): @@ -8732,7 +9019,8 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm @@ -8948,9 +9236,11 @@ def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + tile_schedulers = [ TileSchedulerType.Default, ] @@ -9074,9 +9364,11 @@ def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + tile_schedulers = [ TileSchedulerType.Default, ] @@ -9200,7 +9492,8 @@ def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm @@ -9326,9 +9619,11 @@ def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + tile_schedulers = [ TileSchedulerType.Default, ] @@ -9465,9 +9760,11 @@ def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + tile_schedulers = [ TileSchedulerType.Default, ] @@ -9678,9 +9975,11 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): } ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + math_instructions_1sm = [ MathInstruction( [128, 256, 8], @@ -9772,9 +10071,11 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + math_instructions_1sm = [ MathInstruction( [128, 256, 16], @@ -9934,9 +10235,11 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], ] - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + min_cc = 100 max_cc = thor_sm + epi_type = DataType.f32 math_instructions_1sm = [ @@ -10084,7 +10387,8 @@ def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + minimum_compute_capability = 100 maximum_compute_capability = thor_sm @@ -10238,7 +10542,8 @@ def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return - thor_sm = 101 + thor_sm = ThorSMRenumbering(cuda_version) + minimum_compute_capability = 100 maximum_compute_capability = thor_sm @@ -10422,7 +10727,7 @@ def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud return [TileSchedulerType.Default, TileSchedulerType.StreamK] min_cc = 120 - max_cc = 120 + max_cc = 121 epi_type = DataType.f32 @@ -10567,7 +10872,7 @@ def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio return [TileSchedulerType.Default, TileSchedulerType.StreamK] min_cc = 120 - max_cc = 120 + max_cc = 121 epi_type = DataType.f32 @@ -10720,7 +11025,7 @@ def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): return [TileSchedulerType.Default] min_cc = 120 - max_cc = 120 + max_cc = 121 kernel_schedules = [ KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120, @@ -10840,7 +11145,7 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, return [TileSchedulerType.Default] min_cc = 120 - max_cc = 120 + max_cc = 121 kernel_schedulers = [ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120, @@ -10924,7 +11229,11 @@ def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind = gemm_kind) def GenerateSM100(manifest, cuda_version): - arch_family_cc = ['100f', '101f'] + arch_family_cc = ['100f', '101f', '103a'] + if CudaToolkitVersionSatisfies(cuda_version, 13, 0): + for old_cc, new_cc in [('101f', '110f')]: + arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc] + # # Dense Gemm # @@ -10966,8 +11275,11 @@ def GenerateSM100(manifest, cuda_version): # Block Scaled Gemm # GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) # # Conv # @@ -11413,7 +11725,6 @@ def numeric_log_level(log_level: str) -> int: raise ValueError(f'Invalid log level: {log_level}') return numeric_level - # This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface # to leverage the functionality in this file without running this script via a shell prompt. def define_parser(): @@ -11438,6 +11749,11 @@ def define_parser(): parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') + parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list') + parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler') + parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000']) + parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list') + parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py') parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, help='Specify the output log file containing all enabled kernels in this build') parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") @@ -11460,6 +11776,9 @@ if __name__ == "__main__": archs = args.architectures.split(';') + if args.heuristics_problems_file: + filter_manifest_and_write_heuristics_file(manifest, args) + GenerateSM50(manifest, args.cuda_version) GenerateSM60(manifest, args.cuda_version) GenerateSM61(manifest, args.cuda_version) @@ -11468,17 +11787,20 @@ if __name__ == "__main__": GenerateSM80(manifest, args.cuda_version) GenerateSM89(manifest, args.cuda_version) GenerateSM90(manifest, args.cuda_version) - + blackwell_arch_list = [ "100a", "100f", "101a", "101f", - "120a", "120f" + "103a", "103f", + "110a", "110f", + "120a", "120f", + "121a", "121f", ] blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) if blackwell_enabled_arch: GenerateSM100(manifest, args.cuda_version) GenerateSM120(manifest, args.cuda_version) - + if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) diff --git a/python/cutlass_library/heuristics.py b/python/cutlass_library/heuristics.py new file mode 100644 index 00000000..dc69e103 --- /dev/null +++ b/python/cutlass_library/heuristics.py @@ -0,0 +1,414 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for selecting CUTLASS library kernels based on problem description +""" +import json +import csv + +try: + if CUTLASS_IGNORE_PACKAGE: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.generator import * + from cutlass_library.heuristics_provider import * +except ImportError: + from library import * + from generator import * + from heuristics_provider import * + +try: + from .sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +_LOGGER = logging.getLogger(__name__) + +dtype_map = {v: k for k, v in DataTypeNames.items()} + +def serialize_heuristics_results_to_json(problems_with_configs, outfile_path): + """ + Utilitiy function to write heuristics results to a json file for debug + + args: + problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict + outfile_path: Outfile path + + returns: + None + """ + pc_copy = problems_with_configs.copy() + for p in pc_copy: + for k, v in p.items(): + if isinstance(v, DataType): + p[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + p[k] = ShortLayoutTypeNames[v] + configs = p['configs'] + for c in configs: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + with open(outfile_path, 'w') as f: + json.dump(pc_copy, f, indent=2) + +def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None): + """ + Get heuristic-suggested GEMM kernel configurations for a single GEMM problem. + + args: + m, n, k: GEMM dimensions + batch_count: batch count + layouts: tuple of layouts of type LayoutType + use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions + count: Number of configs to return + provider: Heuristics provider to use + + returns: + A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys: + - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size + - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size + - 'stages': kernel pipeline stage count + - 'cluster_m', 'cluster_n', 'cluster_k': cluster size + - 'layout_a', 'layout_b': input tensor layouts of type LayoutType + - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements + - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType + - 'swizzle_size' : suggested threadblock swizzle + - 'split_k_slices': number of partitions of the k dimension for splitK + - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n') + """ + if provider is None: + provider = MatmulHeuristics() + return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count) + +def get_gemm_configs(problems, provider=None, count=1): + """ + Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems. + + args: + problems: List of dictionaries describing GEMM problems with the following keys: + - 'm', 'n', 'k': Matrix dimensions (required) + - 'dtype_a': Data type of matrix A (required) + - 'dtype_b': Data type of matrix B (required) + - 'dtype_c': Data type of matrix C (default: None) + - 'dtype_d': Data type of matrix D (required) + - 'dtype_acc': Compute data type (default 'f32') + - 'layout': Operation layout (e.g. 'tnt') + - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements) + - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements) + - 'alpha': Scalar multiplier for A*B (default: 1.0) + - 'beta': Scalar multiplier for C (default: 0.0) + - 'batch_count': Number of GEMM operations in batch (default: 1) + - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True) + provider: Heuristics provider to use + count: Number of configurations to return per problem (defualt: 1) + + returns: + A copy of the input dictionary, with key `configs` added containing the selected gemm configs + """ + ret = [] + + for problem in problems: + problem = problem.copy() + + try: + m = problem['m'] + n = problem['n'] + k = problem['k'] + dtype_a = problem['dtype_a'] + dtype_b = problem['dtype_b'] + dtype_d = problem['dtype_d'] + layout = problem['layout'] + except KeyError as e: + _LOGGER.error(f"Missing required parameter {e} for problem {problem}") + raise + + operation = problem.get('operation', 'gemm') + batch_count = problem.get('batch_count', 1) + dtype_acc = problem.get('dtype_acc', 'f32') + dtype_c = problem.get('dtype_c', None) + alpha = problem.get('alpha', 1.0) + beta = problem.get('beta', 0.0) + use_fast_acc = problem.get('use_fast_acc', True) + + if operation != OperationKindNames[OperationKind.Gemm]: + raise ValueError(f"Unsupported operation {operation}") + if not (len(layout) == 3 and all(c in "nt" for c in layout)): + raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}") + layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout) + + try: + dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()] + dtypes = tuple(dtype_map[dt] for dt in dtype_list) + except KeyError as dt: + _LOGGER.error(f"Unsupported data type: {dt}") + raise + + alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]]) + alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]]) + + configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider) + problem['configs'] = configs + + ret.append(problem) + + return ret + + +def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc = 100 + max_cc = 101 + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # nvMMH assumes 2sm instruction for !(cluster_m % 2) + is_2sm = config['cluster_m'] % 2 == 0 + instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4] + math_instruction = MathInstruction( + instruction_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = [ + { + "a_type" : math_instruction.element_a, + "b_type" : math_instruction.element_b, + "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator, + "d_type" : element_d, + "acc_type" : math_instruction.element_accumulator, + "epi_type" : math_instruction.element_accumulator, + } + ] + + tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k']) + tile_description = TileDescription( + [instruction_shape[0] * tile_multiplier[0], + instruction_shape[1] * tile_multiplier[1], + instruction_shape[2] * 4 * tile_multiplier[2]], + 0, + [4,1,1], + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules = [] + if is_2sm: + schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]) + else: + schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]) + + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + + return configs, operations + + +def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc, max_cc = 90, 90 + + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + + is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128) + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_instr_shape = [0, 0, 0] + math_instruction = MathInstruction( + dummy_instr_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d) + if is_aligned: + layout = fix_alignments(data_types, layout, alignment_bits=128) + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_warp_count = [0, 0, 0] + tile_description = TileDescription( + [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']], + 0, + dummy_warp_count, + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_description, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_types, + instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic + layout=layout, + gemm_kind=GemmKind.Universal3x, + enable_fp8_fast_acc=config['use_fast_acc'] + ) + + if len(schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + if len(stream_k_schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]): + configs.append(config) + operations.append(o) + + + return configs, operations + +def filter_manifest_and_write_heuristics_file(manifest, args): + """ + Prune a manifest according to heuristics suggestions from the problems file + + args: + manifest: Cutlass manifest to prune + args: generator.py args, requires: + - args.heuristics_problems_file + - args.heuristics_gpu + - args.heuristics_testlist_file + + returns: + A list of dictionaries, each of which has information about an operation and a problem from the input problems + """ + heuristics_problems = [] + with open(args.heuristics_problems_file, 'r') as f: + heuristics_problems = json.load(f) + gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu + mmh = MatmulHeuristics(gpu=gpu) + if any(('100' in arch) for arch in args.architectures.split(';')): + mmh.set_cta_div_n(64) + problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem) + + all_configs_and_operations = [] + operations = [] + for problem in problems_with_configs: + if any('90' in arch for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + + operations += problem_operations + problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'} + with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)] + all_configs_and_operations += with_problem_size + + for operation in operations: + manifest.add_kernel_filter(f"^{operation.procedural_name()}$") + if not all_configs_and_operations: + raise Exception("No valid configurations generated") + write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file) + return all_configs_and_operations + +def write_profiler_testlist_to_csv(configs_list, outfile_path): + """ + Write a list of configs to a testlist to be consumed by cutlass_profiler + + args: + configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries + outfile_path: Outfile path + + returns: + None + """ + profiler_testlist = configs_list.copy() + for c in profiler_testlist: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + + with open(outfile_path, mode='w', newline='') as ofile: + k_names = profiler_testlist[0].keys() + + writer = csv.DictWriter(ofile, fieldnames=k_names) + writer.writeheader() + writer.writerows(profiler_testlist) diff --git a/python/cutlass_library/heuristics_provider.py b/python/cutlass_library/heuristics_provider.py new file mode 100644 index 00000000..9baff5c0 --- /dev/null +++ b/python/cutlass_library/heuristics_provider.py @@ -0,0 +1,168 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Providers for kernel selection heuristics +""" + +import sys +import os +import glob +import logging +import ctypes +import functools + +from library import DataType, LayoutType + +class MatmulHeuristics: + + def __init__(self, gpu = None): + import nvMatmulHeuristics + self.mmh_lib = nvMatmulHeuristics + self.gpu = gpu + + if 'CUTLASS_NVMMH_SO_PATH' in os.environ: + nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH']) + else: + nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx + + self.lh = nvmmhInterfaceEx( + backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"], + flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING, + load_discovery_implicitly=True, + gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None + ) + self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"]) + + def _layout_from_cutlass(self, layouts): + assert(len(layouts)==3) + full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts) + input_layouts = full_layout_str[:2].upper() + lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR") + return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout] + + def _precision_from_cutlass_dtypes(self, dtypes): + dtype_to_cublas = { + DataType.f64: 'D', + DataType.f32: 'S', + DataType.f16: 'H', + DataType.bf16: 'T', + DataType.e4m3: 'Q', + DataType.e5m2: 'R', + DataType.s32: 'I', + DataType.s8: 'B', + } + + dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes + + a_c = dtype_to_cublas[dtype_a] + + if a_c.lower() != 'q': + return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + else: + return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + + def set_cta_div_n(self, div_n): + cta_n_div_requirement = ctypes.c_int(div_n) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, + ctypes.byref(cta_n_div_requirement), + ctypes.sizeof(cta_n_div_requirement) + ) + + def set_cta_div_m(self, div_m): + cta_m_div_requirement = ctypes.c_int(div_m) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, + ctypes.byref(cta_m_div_requirement), + ctypes.sizeof(cta_m_div_requirement) + ) + + def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1): + if use_fast_acc: + disable_fast_acc_for_fp8 = ctypes.c_int(0) + else: + disable_fast_acc_for_fp8 = ctypes.c_int(1) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8, + ctypes.byref(disable_fast_acc_for_fp8), + ctypes.sizeof(disable_fast_acc_for_fp8) + ) + + precision = self._precision_from_cutlass_dtypes(dtypes) + layout = self._layout_from_cutlass(layouts) + + matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count) + configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision) + + ret = [] + for c in configs: + kernel = c['kernel'] + problem = c['problem'] + + r = {} + r['estimated_runtime'] = c['runtime'] + r['cta_tile_m'] = kernel.cta_tile_m + r['cta_tile_n'] = kernel.cta_tile_n + r['cta_tile_k'] = kernel.cta_tile_k + r['instr_tile_m'] = kernel.instr_tile_m + r['instr_tile_n'] = kernel.instr_tile_n + r['instr_tile_k'] = kernel.instr_tile_k + r['warp_tile_m'] = kernel.warp_tile_m + r['warp_tile_n'] = kernel.warp_tile_n + r['warp_tile_k'] = kernel.warp_tile_k + r['cluster_m'] = kernel.cluster_m + r['cluster_n'] = kernel.cluster_n + r['cluster_k'] = 1 + r['layout_a'] = layouts[0] + r['layout_b'] = layouts[1] + r['layout_d'] = layouts[2] + r['dtype_a'] = dtypes[0] + r['dtype_b'] = dtypes[1] + r['dtype_acc'] = dtypes[2] + r['dtype_c'] = dtypes[3] + r['dtype_d'] = dtypes[4] + r['alignment_a'] = align_a + r['alignment_b'] = align_b + r['swizzle_size'] = kernel.swizzle_factor + r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n' + r['split_k_slices'] = kernel.split_k + r['use_fast_acc'] = use_fast_acc + r['voidC'] = voidC + + ret.append(r) + + return ret + diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 17e6c5ce..2e1bd82a 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -546,6 +546,22 @@ class KernelScheduleType(enum.Enum): Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + # FP4 Ultra + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() @@ -603,6 +619,22 @@ KernelScheduleTag = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + # FP4 Ultra + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', @@ -677,6 +709,21 @@ KernelScheduleSuffixes = { KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_1sm', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_2sm', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_1sm', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_2sm', + + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_1sm_nopf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_2sm_nopf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_1sm_nopf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_2sm_nopf', + + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_1sm_tmapf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_2sm_tmapf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_1sm_tmapf', + KernelScheduleType.BlockScaledMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_2sm_tmapf', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', @@ -713,8 +760,12 @@ class EpilogueScheduleType(enum.Enum): PtrArrayNoSmemWarpSpecialized = enum_auto() NoSmemWarpSpecialized1Sm = enum_auto() NoSmemWarpSpecialized2Sm = enum_auto() + FastF32NoSmemWarpSpecialized1Sm = enum_auto() + FastF32NoSmemWarpSpecialized2Sm = enum_auto() PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() @@ -732,8 +783,12 @@ EpilogueScheduleTag = { EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm', EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', @@ -752,8 +807,12 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem', EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index 74eea09a..baaaac28 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -526,44 +526,49 @@ class Manifest: if args.filter_by_cc in ['false', 'False', '0']: self.filter_by_cc = False - if args.operations == 'all': - self.operations_enabled = [] - else: - operations_list = [ - OperationKind.Gemm - , OperationKind.Conv2d - , OperationKind.Conv3d - , OperationKind.RankK - , OperationKind.Trmm - , OperationKind.Symm - ] - self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + if args.operations == 'all': + self.operations_enabled = [] + else: + operations_list = [ + OperationKind.Gemm + , OperationKind.Conv2d + , OperationKind.Conv3d + , OperationKind.RankK + , OperationKind.Trmm + , OperationKind.Symm + ] + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] - if args.kernels == 'all': - self.kernel_names = [] - else: - self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + if args.kernels == 'all': + self.kernel_names = [] + else: + self.kernel_names = [x for x in args.kernels.split(',') if x != ''] - self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] - self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] + self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] - if args.kernel_filter_file is None: - self.kernel_filter_list = [] - else: - self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) - _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( - filter_count = len(self.kernel_filter_list), - filter_file = args.kernel_filter_file)) + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( + filter_count = len(self.kernel_filter_list), + filter_file = args.kernel_filter_file)) - self.operation_count = 0 - self.operations_by_name = {} - self.disable_full_archs_compilation = args.disable_full_archs_compilation - self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' - self.instantiation_level = 0 - try: - self.instantiation_level = int(args.instantiation_level) - except ValueError: - self.instantiation_level = 0 + self.operation_count = 0 + self.operations_by_name = {} + self.disable_full_archs_compilation = args.disable_full_archs_compilation + self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' + self.instantiation_level = 0 + try: + self.instantiation_level = int(args.instantiation_level) + except ValueError: + self.instantiation_level = 0 + + def add_kernel_filter(self, filter_str): + filter_re = re.compile(filter_str) + + self.kernel_filter_list.append(filter_re) def get_sm90_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): # Non-negative integer which determines how many kernels are instantiated. diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 8ea870ec..3bf3edb2 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -407,7 +407,7 @@ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: def is_tile_desc_compatible_with_cooperative(tile_description): # Cooperative kernels require a minimum CTA-M of 128 - return tile_description.threadblock_shape[0] >= 128 + return tile_description.threadblock_shape[0] % 128 == 0 def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index bdef2e32..8122b7a6 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -50,17 +50,17 @@ setup_pycute.perform_setup() setup( - name='cutlass', - version='3.4.0', + name='cutlass_cppgen', + version='4.0.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ - 'cutlass', - 'cutlass.emit', - 'cutlass.op', - 'cutlass.utils', - 'cutlass.backend', - 'cutlass.backend.utils' + 'cutlass_cppgen', + 'cutlass_cppgen.emit', + 'cutlass_cppgen.op', + 'cutlass_cppgen.utils', + 'cutlass_cppgen.backend', + 'cutlass_cppgen.backend.utils' ], setup_requires=['pybind11'], install_requires=[ diff --git a/setup.cfg b/setup.cfg index a4216c2b..98791943 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,26 +1,26 @@ [metadata] name = nvidia-cutlass -version = 3.4.0.0 +version = 4.0.0.0 [options] packages = - cutlass - cutlass.backend - cutlass.backend.evt - cutlass.backend.evt.backend - cutlass.backend.evt.frontend - cutlass.backend.evt.ir - cutlass.backend.evt.passes - cutlass.backend.utils - cutlass.emit - cutlass.epilogue - cutlass.op - cutlass.utils + cutlass_cppgen + cutlass_cppgen.backend + cutlass_cppgen.backend.evt + cutlass_cppgen.backend.evt.backend + cutlass_cppgen.backend.evt.frontend + cutlass_cppgen.backend.evt.ir + cutlass_cppgen.backend.evt.passes + cutlass_cppgen.backend.utils + cutlass_cppgen.emit + cutlass_cppgen.epilogue + cutlass_cppgen.op + cutlass_cppgen.utils cutlass_library cutlass_library.source pycute package_dir = - cutlass=python/cutlass + cutlass_cppgen=python/cutlass cutlass_library=python/cutlass_library cutlass_library.source=. pycute=python/pycute diff --git a/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/test/python/cutlass/conv2d/conv2d_problem_sizes.py index d16338d9..852c0277 100644 --- a/test/python/cutlass/conv2d/conv2d_problem_sizes.py +++ b/test/python/cutlass/conv2d/conv2d_problem_sizes.py @@ -38,8 +38,8 @@ This file was ported from the C++ version in test/unit/conv/device/conv2d_proble from cutlass_library import ConvMode -import cutlass -from cutlass.shape import Conv2DProblemSize +import cutlass_cppgen +from cutlass_cppgen.shape import Conv2DProblemSize class TestbedConv2dProblemSizes: diff --git a/test/python/cutlass/conv2d/conv2d_sm80.py b/test/python/cutlass/conv2d/conv2d_sm80.py index fd59cbdd..f77a0ec8 100644 --- a/test/python/cutlass/conv2d/conv2d_sm80.py +++ b/test/python/cutlass/conv2d/conv2d_sm80.py @@ -37,13 +37,13 @@ Low-level functionality tests for Conv2d opreations on SM80 import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from conv2d_test_utils import * -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 @@ -62,54 +62,54 @@ conv_problems = get_conv_problems() for conv_kind in ["fprop", "wgrad", "dgrad"]: # F16, simt add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="simt", threadblock_shape=[128, 128, 8], warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1]) # F16, tensor op add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) # F16, tensor op, analytic iterator add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic") # F16, tensor op, f32 output add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) # F16, tensor op, different tile description add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 64, 32], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]) # F32, simt add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, opclass="simt", threadblock_shape=[128, 128, 8], warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1]) # Tf32, tensorop add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, opclass="tensor_op", threadblock_shape=[128, 128, 16], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8] ) # Split-K add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial", split_k_slices=2) add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel", split_k_slices=5) # Swizzling functor add_test( - Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 64, 32], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4) @@ -120,14 +120,14 @@ for c, tb, stage, inst in zip([2, 1], [3, 2], [[16, 8, 16], [16, 8, 8]]): add_test( - Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=tb, warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels" ) # F16, tensor op, fixed channels for c in [8, 4, 2]: add_test( - Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels" ) @@ -136,7 +136,7 @@ for c in [8, 4, 2]: for activation in ["relu", "leaky_relu"]: for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]): add_test( - Conv2dSm80, cc, "fprop", conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, opclass="tensor_op", threadblock_shape=[128, 128, 64], warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode, split_k_slices=split_k_slices, activation=activation) diff --git a/test/python/cutlass/conv2d/conv2d_test_utils.py b/test/python/cutlass/conv2d/conv2d_test_utils.py index 454eb22b..9bc4542c 100644 --- a/test/python/cutlass/conv2d/conv2d_test_utils.py +++ b/test/python/cutlass/conv2d/conv2d_test_utils.py @@ -37,7 +37,7 @@ Utility functions for Conv2d tests. from cutlass_library import SubstituteTemplate import torch -import cutlass +import cutlass_cppgen from cutlass_library import ( ConvKind, ConvMode, @@ -51,8 +51,8 @@ from cutlass_library import ( ShortLayoutTypeNames, SplitKMode, ) -from cutlass.shape import Conv2DProblemSize -from cutlass.utils.datatypes import numpy_type, torch_type +from cutlass_cppgen.shape import Conv2DProblemSize +from cutlass_cppgen.utils.datatypes import numpy_type, torch_type from conv2d_problem_sizes import TestbedConv2dProblemSizes @@ -88,7 +88,7 @@ def get_name_conv2d( :param element_c: data type of operand C :param element_accumulator: data type used in accumulation :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpcodeClass + :type opclass: cutlass_cppgen.OpcodeClass :param threadblock_shape: indexable container of dimensions of threadblock tiles :param stages: number of pipeline stages to use in the kernel :type stages: int @@ -216,7 +216,7 @@ def validate_problem_size(ps, conv_kind, split_k_slices): class Conv2dLauncherFrontend: - def __init__(self, plan: cutlass.Conv2d, seed: int = 80, backend="numpy"): + def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"): self.operation = plan self.conv_kind = plan.conv_kind self.seed = seed @@ -233,7 +233,7 @@ class Conv2dLauncherFrontend: self.element_compute = DataType.f32 - if self.dtype_A in [cutlass.DataType.f16, cutlass.DataType.bf16]: + if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]: self.rand_max = 1 else: self.rand_max = 4 @@ -273,9 +273,9 @@ class Conv2dLauncherFrontend: else: raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") - if activation == cutlass.backend.epilogue.relu: + if activation == cutlass_cppgen.backend.epilogue.relu: torch_result = torch.nn.functional.relu(torch_result) - elif activation == cutlass.backend.epilogue.leaky_relu: + elif activation == cutlass_cppgen.backend.epilogue.leaky_relu: torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) return torch_result @@ -345,7 +345,7 @@ def add_test( def run(self): # Create the plan - plan = cutlass.Conv2d( + plan = cutlass_cppgen.Conv2d( kind=conv_kind, element=element, element_accumulator=element_accumulator, @@ -373,9 +373,9 @@ def add_test( if activation != "identity": if activation == "leaky_relu": - plan.activation = (cutlass.epilogue.leaky_relu, 0.5) + plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5) else: - plan.activation = getattr(cutlass.epilogue, activation) + plan.activation = getattr(cutlass_cppgen.epilogue, activation) conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch") diff --git a/test/python/cutlass/emit/pytorch.py b/test/python/cutlass/emit/pytorch.py index 5fe5b1a5..c9d4c52a 100644 --- a/test/python/cutlass/emit/pytorch.py +++ b/test/python/cutlass/emit/pytorch.py @@ -40,9 +40,9 @@ import unittest from cutlass_library import ConvMode -import cutlass +import cutlass_cppgen -if cutlass.utils.datatypes.is_torch_available(): +if cutlass_cppgen.utils.datatypes.is_torch_available(): import torch @@ -95,7 +95,7 @@ def _generate_conv2d_problem(conv_kind, dtype, ps): :type conv_kind: str :param dtype: data type of tensors :param problem_size: the conv2d problem size - :type problem_size: cutlass.shape.Conv2DProblemSize + :type problem_size: cutlass_cppgen.shape.Conv2DProblemSize :return: initialized tensors A, B, C, and D :rtype: list @@ -116,18 +116,18 @@ def _generate_conv2d_problem(conv_kind, dtype, ps): return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] -@unittest.skipIf(not cutlass.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests') +@unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests') class PyTorchExtensionTest(unittest.TestCase): def test_gemm(self): random.seed(2023) dtype = torch.float16 - plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) A, B, C, _ = _initialize(dtype, 1024, 256, 512) @@ -154,11 +154,11 @@ class PyTorchExtensionTest(unittest.TestCase): random.seed(2023) dtype = torch.float16 - plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) As, Bs, Cs, _ = _generate_problems(dtype, 50) @@ -189,14 +189,14 @@ class PyTorchExtensionTest(unittest.TestCase): torch.manual_seed(2023) dtype = torch.float16 - plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) plan.activation = "relu" op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - problem_size = cutlass.shape.Conv2DProblemSize( + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, 0, 0, @@ -231,13 +231,13 @@ class PyTorchExtensionTest(unittest.TestCase): def test_conv2d_dgrad(self): torch.manual_seed(2023) dtype = torch.float16 - plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - problem_size = cutlass.shape.Conv2DProblemSize( + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, 0, 0, @@ -265,13 +265,13 @@ class PyTorchExtensionTest(unittest.TestCase): def test_conv2d_wgrad(self): torch.manual_seed(2023) dtype = torch.float16 - plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: - mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - problem_size = cutlass.shape.Conv2DProblemSize( + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, 0, 0, diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index 43a9c02d..5467469e 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -36,14 +36,14 @@ Unit test for compute node in SM90 import logging import unittest -import cutlass -from cutlass.backend import * -from cutlass.epilogue import * -from cutlass import swizzle +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen import swizzle from utils.evt_testbed import EVTTestBed, EVTTestCaseBase -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") diff --git a/test/python/cutlass/evt/evt_layout_sm80_90.py b/test/python/cutlass/evt/evt_layout_sm80_90.py index 71e09973..f5a7b7f7 100644 --- a/test/python/cutlass/evt/evt_layout_sm80_90.py +++ b/test/python/cutlass/evt/evt_layout_sm80_90.py @@ -37,13 +37,13 @@ Unit test for store nodes in SM90 import logging import unittest -import cutlass -from cutlass.backend import * -from cutlass.epilogue import * +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * from utils.evt_testbed import EVTTestBed, EVTTestCaseBase -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") diff --git a/test/python/cutlass/evt/evt_load_sm80_90.py b/test/python/cutlass/evt/evt_load_sm80_90.py index 1b1a4fa6..57a5c6bb 100644 --- a/test/python/cutlass/evt/evt_load_sm80_90.py +++ b/test/python/cutlass/evt/evt_load_sm80_90.py @@ -37,13 +37,13 @@ Unit test for load nodes in SM90 import logging import unittest -import cutlass -from cutlass.backend import * -from cutlass.epilogue import * +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * from utils.evt_testbed import EVTTestBed, EVTTestCaseBase -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") diff --git a/test/python/cutlass/evt/evt_mixed_sm80_90.py b/test/python/cutlass/evt/evt_mixed_sm80_90.py index 45392af6..30dc8fe0 100644 --- a/test/python/cutlass/evt/evt_mixed_sm80_90.py +++ b/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -37,14 +37,14 @@ Unittest for mixed types of nodes in SM90 import logging import unittest -import cutlass -from cutlass.backend import * -from cutlass.epilogue import * -from cutlass.swizzle import ThreadblockSwizzleStreamK +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK from utils.evt_testbed import EVTTestBed, EVTTestCaseBase -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") diff --git a/test/python/cutlass/evt/evt_store_sm80_90.py b/test/python/cutlass/evt/evt_store_sm80_90.py index 9ff3d7d7..b47f11e4 100644 --- a/test/python/cutlass/evt/evt_store_sm80_90.py +++ b/test/python/cutlass/evt/evt_store_sm80_90.py @@ -37,13 +37,13 @@ Unit test for store nodes in SM90 import logging import unittest -import cutlass -from cutlass.backend import * -from cutlass.epilogue import * +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * from utils.evt_testbed import EVTTestBed, EVTTestCaseBase -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") diff --git a/test/python/cutlass/evt/utils/evt_testbed.py b/test/python/cutlass/evt/utils/evt_testbed.py index bd027803..62d375d8 100644 --- a/test/python/cutlass/evt/utils/evt_testbed.py +++ b/test/python/cutlass/evt/utils/evt_testbed.py @@ -37,12 +37,12 @@ Testbed classes of EVT import torch import unittest -import cutlass -from cutlass import Tensor -import cutlass.backend.evt -from cutlass.shape import GemmCoord -from cutlass.utils.datatypes import torch_type -from cutlass.utils.profiler import CUDAEventProfiler +import cutlass_cppgen +from cutlass_cppgen import Tensor +import cutlass_cppgen.backend.evt +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import torch_type +from cutlass_cppgen.utils.profiler import CUDAEventProfiler class EVTReferenceModule: @@ -53,19 +53,19 @@ class EVTReferenceModule: self.epilogue_visitor = epilogue_visitor def run(self, A, B, C, problem_size, alpha, beta, batch=1): - if self.layout_A == cutlass.LayoutType.RowMajor: + if self.layout_A == cutlass_cppgen.LayoutType.RowMajor: A_row = A.view((batch, problem_size.m, problem_size.k)) else: A_col = A.view((batch, problem_size.k, problem_size.m)) A_row = torch.permute(A_col, (0, 2, 1)) - if self.layout_B == cutlass.LayoutType.RowMajor: + if self.layout_B == cutlass_cppgen.LayoutType.RowMajor: B_row = B.view((batch, problem_size.k, problem_size.n)) else: B_col = B.view((batch, problem_size.n, problem_size.k)) B_row = torch.permute(B_col, (0, 2, 1)) - if self.layout_C == cutlass.LayoutType.RowMajor: + if self.layout_C == cutlass_cppgen.LayoutType.RowMajor: C_row = C.view((batch, problem_size.m, problem_size.n)) else: C_col = C.view((batch, problem_size.n, problem_size.m)) @@ -73,7 +73,7 @@ class EVTReferenceModule: out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta - if self.layout_C == cutlass.LayoutType.ColumnMajor: + if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor: out = torch.permute(out_row, (0, 2, 1)) else: out = out_row @@ -102,11 +102,11 @@ class EVTTestBed: """ def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None: self.element = element - layout = cutlass.LayoutType.RowMajor + layout = cutlass_cppgen.LayoutType.RowMajor self.example_inputs = example_inputs # Create the Gemm plan - self.plan = cutlass.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) + self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) if "tile_description" in kwargs: self.plan.tile_description = kwargs["tile_description"] @@ -115,7 +115,7 @@ class EVTTestBed: self.plan.swizzling_functor = kwargs["swizzling_functor"] # Compile the epilogue visitor - epilogue_visitor = cutlass.epilogue.trace(evt_fn, example_inputs) + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs) if "epilogue_stages" in kwargs: epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"] self.plan.epilogue_visitor = epilogue_visitor @@ -205,7 +205,7 @@ class EVTTestCaseBase(unittest.TestCase): def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None: super().__init__(methodName) - self.element = cutlass.DataType.f16 + self.element = cutlass_cppgen.DataType.f16 self.l, self.m, self.n, self.k = lmnk self.problem_size = (self.m, self.n, self.k) @@ -214,7 +214,7 @@ class EVTTestCaseBase(unittest.TestCase): def fake_tensor(self, element, shape, stride=None): if stride is None: - return Tensor(element=element, shape=shape, layout_tag=cutlass.LayoutType.RowMajor) + return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor) else: return Tensor(element=element, shape=shape, stride=stride) diff --git a/test/python/cutlass/gemm/gemm_batched.py b/test/python/cutlass/gemm/gemm_batched.py index a4303970..155426ab 100644 --- a/test/python/cutlass/gemm/gemm_batched.py +++ b/test/python/cutlass/gemm/gemm_batched.py @@ -39,13 +39,13 @@ import logging from math import prod import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc import torch from utils import LayoutCombination -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) torch.manual_seed(2023) @@ -101,7 +101,7 @@ class GemmF16Batched(unittest.TestCase): C = initialize(M, N, batch_count if batch_C else (1,)) D = initialize(M, N, batch_count) - plan = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass.DataType.f32) + plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) plan.run(A, B, C, D, alpha, beta) reference = pytorch_reference(A, B, C, alpha, beta) assert reference.equal(D) diff --git a/test/python/cutlass/gemm/gemm_f16_sm80.py b/test/python/cutlass/gemm/gemm_f16_sm80.py index 4c8ed29e..dbd26951 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm80.py +++ b/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -38,18 +38,18 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 -dtype = cutlass.DataType.f16 +dtype = cutlass_cppgen.DataType.f16 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -58,7 +58,7 @@ class GemmF16Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -68,61 +68,61 @@ class GemmF16Sm80StreamK(unittest.TestCase): add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) # Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) if __name__ == '__main__': unittest.main() diff --git a/test/python/cutlass/gemm/gemm_f16_sm90.py b/test/python/cutlass/gemm/gemm_f16_sm90.py index 445a096f..61aa295b 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm90.py +++ b/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -38,18 +38,18 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 90 -dtype = cutlass.DataType.f16 +dtype = cutlass_cppgen.DataType.f16 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF16Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -60,87 +60,87 @@ class GemmF16Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, warp_count=None, compilation_modes=['nvcc']) -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) # Tests with 1x1x1 clusters add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) -add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) -add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) -add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) +add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) # Tests with different cluster shapes add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f16, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 4, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[4, 1, 1]) -add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, - element_accumulator=cutlass.DataType.f32, cluster_shape=[4, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1]) # Tests for different schedule modes add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], - element_output=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, - opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) + element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) add_test_schedule( cluster_shape=[1, 1, 1], - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized ) add_test_schedule( cluster_shape=[1, 1, 1], - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative ) add_test_schedule( cluster_shape=[2, 1, 1], - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized ) add_test_schedule( cluster_shape=[2, 1, 1], - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative ) # Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) -add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8]) -add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8]) -add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8]) -add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8]) -add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8]) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) +add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8]) +add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8]) +add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8]) +add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8]) +add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8]) # Tests with void-C kernels -add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, - element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, - cluster_shape=[2, 1, 1], element_C=cutlass.DataType.void) +add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, + cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void) if __name__ == '__main__': unittest.main() diff --git a/test/python/cutlass/gemm/gemm_f32_sm80.py b/test/python/cutlass/gemm/gemm_f32_sm80.py index c5b85170..bf662b92 100644 --- a/test/python/cutlass/gemm/gemm_f32_sm80.py +++ b/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 -dtype = cutlass.DataType.f32 +dtype = cutlass_cppgen.DataType.f32 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF32Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -59,7 +59,7 @@ class GemmF32Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF32Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -70,7 +70,7 @@ class GemmF32Sm80StreamK(unittest.TestCase): add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) @@ -81,7 +81,7 @@ add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) # Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) @@ -95,7 +95,7 @@ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) diff --git a/test/python/cutlass/gemm/gemm_f64_sm80.py b/test/python/cutlass/gemm/gemm_f64_sm80.py index f238890e..3075ddf7 100644 --- a/test/python/cutlass/gemm/gemm_f64_sm80.py +++ b/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 -dtype = cutlass.DataType.f64 +dtype = cutlass_cppgen.DataType.f64 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -59,7 +59,7 @@ class GemmF64Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -70,7 +70,7 @@ class GemmF64Sm80StreamK(unittest.TestCase): add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) @@ -80,7 +80,7 @@ add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) # Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) @@ -94,7 +94,7 @@ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) diff --git a/test/python/cutlass/gemm/gemm_f64_sm90.py b/test/python/cutlass/gemm/gemm_f64_sm90.py index d0d0238d..9bf36fc7 100644 --- a/test/python/cutlass/gemm/gemm_f64_sm90.py +++ b/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 90 -dtype = cutlass.DataType.f64 +dtype = cutlass_cppgen.DataType.f64 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF64Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -61,10 +61,10 @@ class GemmF64Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) -add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) -add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) -add_test_specialized( opclass=cutlass.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) -add_test_specialized( opclass=cutlass.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_f8_sm90.py b/test/python/cutlass/gemm/gemm_f8_sm90.py index 5735e36c..fef6d457 100644 --- a/test/python/cutlass/gemm/gemm_f8_sm90.py +++ b/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 90 -dtype = cutlass.DataType.e4m3 +dtype = cutlass_cppgen.DataType.e4m3 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF8E4M3Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -60,38 +60,38 @@ class GemmF8E4M3Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) # Test with 1x1x1 clusters -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, - element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) # Tests with different cluster shapes -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, - element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) # Tests with warp-specialized ping-pong schedule -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.e4m3, - element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) # Tests for SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.e4m3, - element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) # # Add a test for E5M2 # -dtype = cutlass.DataType.e5m2 +dtype = cutlass_cppgen.DataType.e5m2 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmF8E5M2Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -101,11 +101,11 @@ class GemmF8E5M2Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) # Tests with 1x1x1 clusters add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, - element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_mixed_sm80.py b/test/python/cutlass/gemm/gemm_mixed_sm80.py index 857acd83..0a002a5f 100644 --- a/test/python/cutlass/gemm/gemm_mixed_sm80.py +++ b/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 -dtype =cutlass.DataType.f16 +dtype =cutlass_cppgen.DataType.f16 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmMixedSm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -59,16 +59,16 @@ class GemmMixedSm80(unittest.TestCase): add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], - opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], - warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32) + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) # Test with upcast on A -add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) -add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) # Test with upcast on B -add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) -add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_s8_sm80.py b/test/python/cutlass/gemm/gemm_s8_sm80.py index 38dd307d..e226e236 100644 --- a/test/python/cutlass/gemm/gemm_s8_sm80.py +++ b/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 80 -dtype = cutlass.DataType.s8 +dtype = cutlass_cppgen.DataType.s8 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm80(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -59,7 +59,7 @@ class GemmS8Sm80(unittest.TestCase): @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm80StreamK(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -70,33 +70,33 @@ class GemmS8Sm80StreamK(unittest.TestCase): add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) # Tests using SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_s8_sm90.py b/test/python/cutlass/gemm/gemm_s8_sm90.py index 2cfeadb8..ec0101f7 100644 --- a/test/python/cutlass/gemm/gemm_s8_sm90.py +++ b/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -38,19 +38,19 @@ from functools import partial import logging import unittest -import cutlass -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc from utils import LayoutCombination, add_test_gemm -cutlass.set_log_level(logging.WARNING) +cutlass_cppgen.set_log_level(logging.WARNING) cc = 90 -dtype = cutlass.DataType.s8 +dtype = cutlass_cppgen.DataType.s8 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') -@unittest.skipIf(cutlass.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') class GemmS8Sm90(unittest.TestCase): """ Wrapper class to which tests will be added dynamically in __main__ @@ -60,38 +60,38 @@ class GemmS8Sm90(unittest.TestCase): add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) -add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) # Tests with 1x1x1 clusters -add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) # Tests with different cluster shapes -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) # Tests with warp-specialized ping-pong schedule -add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, - kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, - epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) # Tests for SIMT -add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, - element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_testbed.py b/test/python/cutlass/gemm/gemm_testbed.py index d4631650..50eb7a9b 100644 --- a/test/python/cutlass/gemm/gemm_testbed.py +++ b/test/python/cutlass/gemm/gemm_testbed.py @@ -47,11 +47,11 @@ from cutlass_library import ( SwizzlingFunctor ) -from cutlass.backend import compiler -from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal -from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation -from cutlass.shape import GemmCoord, MatrixCoord -from cutlass.utils.datatypes import torch_type +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass_cppgen.shape import GemmCoord, MatrixCoord +from cutlass_cppgen.utils.datatypes import torch_type class GemmUniversalLauncher: @@ -153,7 +153,7 @@ class GemmUniversalLauncher: else: data_cutlass = data_ref.transpose(-1, -2).contiguous() - data_cutlass = data_cutlass.to("cuda") + data_cutlass = data_cutlass_cppgen.to("cuda") # As of this writing, few operations in PyTorch are supported with FP8 data. # Thus, we perform computation in FP32 for FP8 reference checks. diff --git a/test/python/cutlass/gemm/utils.py b/test/python/cutlass/gemm/utils.py index 6ec92fec..28bba3e9 100644 --- a/test/python/cutlass/gemm/utils.py +++ b/test/python/cutlass/gemm/utils.py @@ -32,7 +32,7 @@ from cutlass_library import SubstituteTemplate -import cutlass +import cutlass_cppgen from cutlass_library import ( DataTypeNames, EpilogueScheduleSuffixes, @@ -42,7 +42,7 @@ from cutlass_library import ( ShortDataTypeNames, ShortLayoutTypeNames ) -from cutlass.backend import library +from cutlass_cppgen.backend import library from gemm_testbed import test_all_gemm @@ -107,11 +107,11 @@ def get_name( :param arch: compute capability of kernel being generated :type arch: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpcodeClass + :type opclass: cutlass_cppgen.OpcodeClass :param kernel_schedule: kernel_schedule type - :type kernel_schedule: cutlass.KernelScheduleType + :type kernel_schedule: cutlass_cppgen.KernelScheduleType :param epilogue_schedule: epilogue_schedule type - :type epilogue_schedule: cutlass.EpilogueScheduleType + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType :param suffix: additional string to add to the suffix of the name :type suffix: str @@ -175,15 +175,15 @@ def add_test_gemm( :param cc: compute capability to compile for :type cc: int :param element: data type of A and B operands - :type element: cutlass.DataType.f16 + :type element: cutlass_cppgen.DataType.f16 :param layouts: layouts of A, B, and C operands :type layouts: list or tuple :param alignments: alingments of A, B, and C operands :type alignments: list or tuple :param element_output: data type of the output element - :type element_output: cutlass.DataType + :type element_output: cutlass_cppgen.DataType :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_cppgen.DataType :param cluster_shape: dimensions of clusters :type cluster_shape: list or tuple :param threadblock_shape: dimensions of threadblock tiles @@ -193,20 +193,20 @@ def add_test_gemm( :param stages: number of pipeline stages to use in the kernel :type stages: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpcodeClass + :type opclass: cutlass_cppgen.OpcodeClass :param swizzle: threadblock swizzling functor :param kernel_schedule: kernel schedule to use - :type kernel_schedule: cutlass.KernelScheduleType + :type kernel_schedule: cutlass_cppgen.KernelScheduleType :param epilogue_schedule: epilogue schedule to use - :type epilogue_schedule: cutlass.EpilogueScheduleType + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') :type compilation_modes: list, :param element_A: data type of operand A. If set, overrides ``element`` - :type element_A: cutlass.DataType + :type element_A: cutlass_cppgen.DataType :param element_B: data type of operand B. If set, overrides ``element`` - :type element_B: cutlass.DataType + :type element_B: cutlass_cppgen.DataType :param element_C: data type of operand C. If set, overrides ``element`` - :type element_C: cutlass.DataType + :type element_C: cutlass_cppgen.DataType """ if element_A is None: @@ -230,7 +230,7 @@ def add_test_gemm( layout_A, layout_B, layout_C = layouts alignment_A, alignment_B, alignment_C = alignments - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, element_D=element_output, layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, element_accumulator=element_accumulator, diff --git a/test/python/cutlass/installation.py b/test/python/cutlass/installation.py index 6f05e5ac..f550c394 100644 --- a/test/python/cutlass/installation.py +++ b/test/python/cutlass/installation.py @@ -37,7 +37,7 @@ Tests for a successful installation of the CUTLASS Python interface import os import unittest -import cutlass +import cutlass_cppgen import cutlass_library @@ -48,7 +48,7 @@ class InstallationTest(unittest.TestCase): """ src_file = 'include/cutlass/cutlass.h' library_file = os.path.join(cutlass_library.source_path, src_file) - cutlass_file = os.path.join(cutlass.CUTLASS_PATH, src_file) + cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." diff --git a/test/python/cutlass/interface/conv2d_interface.py b/test/python/cutlass/interface/conv2d_interface.py index d0f13e04..2b5d46d4 100644 --- a/test/python/cutlass/interface/conv2d_interface.py +++ b/test/python/cutlass/interface/conv2d_interface.py @@ -37,9 +37,9 @@ Tests the high-level Conv2d interface from math import ceil import unittest -import cutlass -import cutlass.utils.datatypes as datatypes -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc from utils import ExpectException import os @@ -62,7 +62,7 @@ class Conv2dEquivalence: self.conv_kind = conv_kind - self.plan = cutlass.op.Conv2d( + self.plan = cutlass_cppgen.op.Conv2d( kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, element_D=element_D, element_accumulator=element_accumulator) @@ -75,7 +75,7 @@ class Conv2dEquivalence: Compares whether two plans are equal :param other_plan: plan to compare against the default Conv2d - :type other_plan: cutlass.op.Conv2d + :type other_plan: cutlass_cppgen.op.Conv2d :return: whether `other_plan` is equivalent to `self.plan` :rtype: bool @@ -95,14 +95,14 @@ class Conv2dEquivalence: return # Test when specifying all parameters - plan_other = cutlass.op.Conv2d( + plan_other = cutlass_cppgen.op.Conv2d( kind=self.conv_kind, element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator) assert self._plans_equal(plan_other) # Test when specifying all parameters but A - plan_other = cutlass.op.Conv2d( + plan_other = cutlass_cppgen.op.Conv2d( kind=self.conv_kind, element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, @@ -110,7 +110,7 @@ class Conv2dEquivalence: assert self._plans_equal(plan_other) # Test when specifying all parameters but A and B as tensors using generic element and output - plan_other = cutlass.op.Conv2d( + plan_other = cutlass_cppgen.op.Conv2d( kind=self.conv_kind, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, @@ -119,7 +119,7 @@ class Conv2dEquivalence: # Test without explicit accumulator. Only run if the type of C and the accumulator are equal if self.element_C == self.element_accumulator: - plan_other = cutlass.op.Conv2d( + plan_other = cutlass_cppgen.op.Conv2d( kind=self.conv_kind, element_C=self.element_C, element_D=self.element_D, @@ -129,7 +129,7 @@ class Conv2dEquivalence: # Test with only the generic types. Only rune if the types of A, B, C, and D are the same if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D and self.element_A == self.element_accumulator): - plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A) + plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A) assert self._plans_equal(plan_other) def numpy_test(self): @@ -179,26 +179,26 @@ class Conv2dEquivalence: def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): # Test when specifying all parameters via tensors - plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) assert self._plans_equal(plan_np) # Test when specifying all parameters but A as tensors - plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) assert self._plans_equal(plan_np) # Test when specifying all parameters but A and B as tensors and using generic element and output if type_A == type_B: - plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) assert self._plans_equal(plan_np) # Test without explicit accumulator. Only run if the type of C and the accumulator. if type_C == type_accum: - plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) assert self._plans_equal(plan_np) # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): - plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A) + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A) assert self._plans_equal(plan_np) def test_all(self): @@ -218,8 +218,8 @@ class ConvEquivalenceTest(unittest.TestCase): pass type2alignment = { - cutlass.DataType.f16: 8, - cutlass.DataType.f32: 4 + cutlass_cppgen.DataType.f16: 8, + cutlass_cppgen.DataType.f32: 4 } def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): @@ -241,11 +241,11 @@ def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accu for conv_kind in ["fprop", "wgrad", "dgrad"]: for types in [ - [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16], - [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32], - [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f16], - [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32], - [cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32] + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32] ]: add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) @@ -260,7 +260,7 @@ class Conv2dErrorTests(unittest.TestCase): """ Tests case in which the alignment specified is unsupported """ - plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) @@ -269,7 +269,7 @@ class Conv2dErrorTests(unittest.TestCase): """ Tests scenarios in which an invalid tile description is provided for a given CC """ - plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) td = plan.tile_descriptions()[0] td.threadblock_shape=[17, 32, 5] diff --git a/test/python/cutlass/interface/evt_interface.py b/test/python/cutlass/interface/evt_interface.py index f2668d70..e7d67f4d 100644 --- a/test/python/cutlass/interface/evt_interface.py +++ b/test/python/cutlass/interface/evt_interface.py @@ -37,10 +37,10 @@ Test the EVT interface import numpy as np import unittest -import cutlass -from cutlass import LayoutType, Tensor -from cutlass.backend.utils.device import device_cc -from cutlass.epilogue import reshape, permute +import cutlass_cppgen +from cutlass_cppgen import LayoutType, Tensor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import reshape, permute from utils import ExpectException @@ -69,7 +69,7 @@ class EVTErrorTests(unittest.TestCase): "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " "but the variable 'D' is not found in the return values.", True): - cutlass.epilogue.trace(evt_root_not_d, example_tensors) + cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) def test_no_accum(self): """ @@ -86,7 +86,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): - cutlass.epilogue.trace(evt_no_accum, example_tensors) + cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") def test_too_much_shared_memory(self): @@ -124,10 +124,10 @@ class EVTErrorTests(unittest.TestCase): "D": self.fake_tensor(np.float16, (6, 512, 512)) } - epilogue_visitor = cutlass.epilogue.trace(evt_too_much_shared_memory, example_tensors) + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) - plan = cutlass.op.Gemm( - element=np.float16, layout=cutlass.LayoutType.RowMajor, + plan = cutlass_cppgen.op.Gemm( + element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, element_accumulator=np.float32 ) @@ -155,7 +155,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): - cutlass.epilogue.trace(evt_redefine, example_tensors) + cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) def evt_undefine(accum, alpha): F = accum + C @@ -170,7 +170,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): - cutlass.epilogue.trace(evt_undefine, example_tensors) + cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) def test_missing_example_tensor(self): """ @@ -186,7 +186,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): - cutlass.epilogue.trace(evt_missing_example_tensor, example_tensors) + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) example_tensors = { "accum": self.fake_tensor(np.float16, (6, 512, 512)), @@ -194,7 +194,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): - cutlass.epilogue.trace(evt_missing_example_tensor, example_tensors) + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) def test_return_expression(self): """ @@ -209,7 +209,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): - cutlass.epilogue.trace(evt_return_expr, example_tensors) + cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) def test_incompatible_shape(self): """ @@ -227,7 +227,7 @@ class EVTErrorTests(unittest.TestCase): with ExpectException(True, "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): - cutlass.epilogue.trace(evt_incompatible_shape, example_tensors) + cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) def test_no_matching_impl(self): def evt_no_matching_impl(accum, bias): @@ -241,7 +241,7 @@ class EVTErrorTests(unittest.TestCase): } with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): - cutlass.epilogue.trace(evt_no_matching_impl, example_tensors) + cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) # # Helper functions # diff --git a/test/python/cutlass/interface/gemm_interface.py b/test/python/cutlass/interface/gemm_interface.py index 85ef228d..723c4c07 100644 --- a/test/python/cutlass/interface/gemm_interface.py +++ b/test/python/cutlass/interface/gemm_interface.py @@ -37,9 +37,9 @@ Tests the high-level GEMM interface from math import ceil import unittest -import cutlass -import cutlass.utils.datatypes as datatypes -from cutlass.backend.utils.device import device_cc +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc from utils import ExpectException @@ -60,7 +60,7 @@ class GemmEquivalence: self.alignment_A = alignment_A self.alignment_B = alignment_B self.alignment_C = alignment_C - self.plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, element_D=element_D, element_accumulator=element_accumulator, layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) @@ -70,7 +70,7 @@ class GemmEquivalence: Compares whether two plans are equal :param other_plan: plan to compare against the default GEMM - :type other_plan: cutlass.op.Gemm + :type other_plan: cutlass_cppgen.op.Gemm :return: whether `other_plan` is equivalent to `self.plan` :rtype: bool @@ -89,13 +89,13 @@ class GemmEquivalence: return # Test when specifying all parameters - plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) assert self._plans_equal(plan_other) # Test when specifying all parameters but A - plan_other = cutlass.op.Gemm(element_B=self.element_B, element_C=self.element_C, + plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, layout_B=self.layout_B, layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) @@ -104,13 +104,13 @@ class GemmEquivalence: # Test when specifying all parameters but A and B as tensors and using generic element and output # Only run this test if the layouts and types for A and B are equal. if self.element_A == self.element_B and self.layout_A == self.layout_B: - plan_other = cutlass.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) assert self._plans_equal(plan_other) # Test without explicit accumulator. Only run if the type of C and the accumulator. if self.element_C == self.element_accumulator: - plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) assert self._plans_equal(plan_other) @@ -119,7 +119,7 @@ class GemmEquivalence: if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D and self.element_A == self.element_accumulator and self.layout_A == self.layout_B and self.layout_A == self.layout_C): - plan_other = cutlass.op.Gemm(element=self.element_A, layout=self.layout_A) + plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) assert self._plans_equal(plan_other) def numpy_test(self): @@ -137,8 +137,8 @@ class GemmEquivalence: type_accum = datatypes.numpy_type(self.element_accumulator) layout_to_order = { - cutlass.LayoutType.RowMajor: 'C', - cutlass.LayoutType.ColumnMajor: 'F' + cutlass_cppgen.LayoutType.RowMajor: 'C', + cutlass_cppgen.LayoutType.ColumnMajor: 'F' } size = (2, 2) A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) @@ -147,28 +147,28 @@ class GemmEquivalence: D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) # Test when specifying all parameters via tensors - plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) assert self._plans_equal(plan_np) # Test when specifying all parameters but A as tensors - plan_np = cutlass.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) + plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) assert self._plans_equal(plan_np) # Test when specifying all parameters but A and B as tensors and using generic element and output # Only run this test if the layouts and types for A and B are equal. if type_A == type_B and self.layout_A == self.layout_B: - plan_np = cutlass.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) assert self._plans_equal(plan_np) # Test without explicit accumulator. Only run if the type of C and the accumulator. if type_C == type_accum: - plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D) + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) assert self._plans_equal(plan_np) # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and self.layout_A == self.layout_B and self.layout_A == self.layout_C): - plan_np = cutlass.op.Gemm(element=type_A, layout=self.layout_A) + plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) assert self._plans_equal(plan_np) def test_all(self): @@ -186,36 +186,36 @@ class GemmEquivalenceTest(unittest.TestCase): @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): gemm_eq = GemmEquivalence( - element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, - layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor, + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, alignment_A=8, alignment_B=8, alignment_C=8) gemm_eq.test_all() @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): gemm_eq = GemmEquivalence( - element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, - layout_A=cutlass.LayoutType.ColumnMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.ColumnMajor, + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, + layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, alignment_A=8, alignment_B=8, alignment_C=8) gemm_eq.test_all() @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): gemm_eq = GemmEquivalence( - element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16, - element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, - layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor, + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, alignment_A=8, alignment_B=8, alignment_C=8) gemm_eq.test_all() @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): gemm_eq = GemmEquivalence( - element_A=cutlass.DataType.f64, element_B=cutlass.DataType.f64, element_C=cutlass.DataType.f64, - element_D=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, - layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor, layout_C=cutlass.LayoutType.RowMajor, + element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, + element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, alignment_A=1, alignment_B=1, alignment_C=1) gemm_eq.test_all() @@ -229,7 +229,7 @@ class GemmErrorTests(unittest.TestCase): """ Tests case in which the alignment specified is unsupported """ - plan = cutlass.op.Gemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) @@ -242,13 +242,13 @@ class GemmErrorTests(unittest.TestCase): # F64 Tensor Core operations are only avaiable on devices with CC >= 80 supports_tensorop_f64 = cc >= 80 - plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f64, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' with ExpectException(not supports_tensorop_f64, error_msg): - plan.opclass = cutlass.OpcodeClass.TensorOp + plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp - expected_opclass = cutlass.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass.OpcodeClass.Simt + expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") @@ -256,25 +256,25 @@ class GemmErrorTests(unittest.TestCase): """ Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) """ - plan = cutlass.op.Gemm( element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) - assert plan.opclass == cutlass.OpcodeClass.TensorOp + plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp # Ensure that all tile descriptions have opclass of TensorOp for td in plan.tile_descriptions(): - assert td.math_instruction.opcode_class == cutlass.OpcodeClass.TensorOp + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp - plan.opclass = cutlass.OpcodeClass.Simt + plan.opclass = cutlass_cppgen.OpcodeClass.Simt # Ensure that all tile descriptions have opclass of Simt for td in plan.tile_descriptions(): - assert td.math_instruction.opcode_class == cutlass.OpcodeClass.Simt + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt def test_invalid_tile_description(self): """ Tests scenarios in which an invalid tile description is provided for a given CC """ cc = device_cc() - plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) td = plan.tile_descriptions()[0] stages = td.stages @@ -292,8 +292,8 @@ class GemmErrorTests(unittest.TestCase): original_kschedule = td.kernel_schedule original_eschedule = td.epilogue_schedule with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): - td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass.EpilogueScheduleType.NoSmemWarpSpecialized + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized td.stages = 3 plan.construct(td) @@ -317,24 +317,24 @@ class GemmErrorTests(unittest.TestCase): td.cluster_shape = cluster_shape with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): - td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): - td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong - td.epilogue_schedule = cutlass.EpilogueScheduleType.ScheduleAuto + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto plan.construct(td) with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): - td.kernel_schedule = cutlass.KernelScheduleType.ScheduleAuto - td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized plan.construct(td) with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): - td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative - td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative - td.tile_scheduler = cutlass.TileSchedulerType.StreamK + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK plan.construct(td) # Ensure that all returned tile descriptions are unique diff --git a/test/unit/core/complex.cu b/test/unit/core/complex.cu index c065c494..cbee05a4 100644 --- a/test/unit/core/complex.cu +++ b/test/unit/core/complex.cu @@ -33,7 +33,8 @@ */ #include -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) #include "../common/cutlass_unit_test.h" diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu index a956500c..af866bc5 100644 --- a/test/unit/cute/ampere/cooperative_gemm.cu +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -609,6 +609,8 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f32_MMA) { test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) + TEST(SM89_CuTe_Ada, CooperativeGemm_e4m3e4m3f16_MMA) { using TA = cutlass::float_e4m3_t; using TB = cutlass::float_e4m3_t; @@ -680,3 +682,5 @@ TEST(SM89_CuTe_Ada, CooperativeGemm_e5m2e5m2f16_MMA) { test_cooperative_gemm_col_major_layout(shape_mnk, tiled_mma); } + +#endif diff --git a/test/unit/cute/turing/movm.cu b/test/unit/cute/turing/movm.cu index 932c3e7f..a72c030b 100644 --- a/test/unit/cute/turing/movm.cu +++ b/test/unit/cute/turing/movm.cu @@ -44,7 +44,7 @@ __global__ void movm_test_device(uint16_t* g_in, uint16_t* g_out) { int tid = threadIdx.x; - + // load input gmem -> register uint32_t reg = reinterpret_cast(g_in)[tid]; @@ -128,7 +128,7 @@ TEST(SM75_CuTe_Turing, Movm) // // CuTe MOVM // - + { thrust::device_vector d_out(count); diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index bd6b0629..3b474b0a 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -345,6 +345,13 @@ cutlass_test_unit_gemm_device_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_group_gemm_pingpong.cu ) +# Blockwise Gemm test +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm90_blockwise + + sm90_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu +) + # Sparse tests # Sparse kernels trigger an ICE in gcc 7.5 if (NOT (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) @@ -801,7 +808,7 @@ cutlass_test_unit_gemm_device_add_executable( hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu ) -if (NOT CUTLASS_NVCC_ARCHS MATCHES 101|101a|101f|103|103a|103f) +if (NOT CUTLASS_NVCC_ARCHS MATCHES 100f|101|101a|101f|103|103a|103f) cutlass_test_unit_gemm_device_add_executable( cutlass_test_unit_gemm_device_blas3_gaussian @@ -947,6 +954,219 @@ cutlass_test_unit_gemm_device_add_executable( sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu ) +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_runtime_datatype_alignx_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype_alignx.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_alignx_sm100 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + sm100_gemm_f8_f8_f8_tensor_op_f32_alignx.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_alignx_streamK_sm100 + + # setting batch size fo 1 to control memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm100_gemm_f8_f8_f8_tensor_op_f32_alignx_streamK.cu +) + +endif() + +if (CUTLASS_NVCC_ARCHS MATCHES 103a|103f) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_1sm + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_1sm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_f4_tensorop_sm103_nosmem + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_tensor_op_f32_nosmem.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_2sm + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_2sm.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_group_1sm_128x128 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x128.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_group_1sm_128x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x192.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_group_2sm_256x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x192.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_group_2sm_256x256 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x256.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_ptr_array_1sm_128x128 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x128.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_ptr_array_1sm_128x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x192.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_ptr_array_2sm_256x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x192.cu +) + +cutlass_test_unit_gemm_device_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_ptr_array_2sm_256x256 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x256.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_streamk + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_stream_k.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_2sm_256x256 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x256.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_2sm_256x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x192.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_2sm_256x128 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x128.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_1sm_128x128 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x128.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_1sm_128x192 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x192.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm103_1sm_128x256 + + # No batching of source to control compiler memory usage + BATCH_SOURCES ON + BATCH_SIZE 1 + + sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x256.cu +) + endif() diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 731076f0..89755dd7 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -1219,6 +1219,33 @@ struct HostCollectiveMainloop +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + // // Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB // diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index 0c8cc2c0..d106a53d 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -830,6 +830,32 @@ struct HostCollectiveMainloop +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; template struct HostCollectiveDefaultEpilogue { diff --git a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu index 043e0740..8bfcd7ed 100644 --- a/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu +++ b/test/unit/gemm/device/sm100_gemm_f4_f4_f32_tensor_op_f32_group_gemm.cu @@ -59,6 +59,56 @@ using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) +TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group_nosmem, 512x256x256_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using MmaTileShape = cute::Shape<_128,_128,_256>; + using ClusterShape = Shape<_4,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 16 / sizeof(ElementC), + ElementD, LayoutC *, 16 / sizeof(ElementC), + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + TEST(SM100_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group, 512x256x256_4x2x1) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -209,6 +259,56 @@ TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_1sm_f32_group, 256x512x256_ EXPECT_TRUE(pass); } +TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_group_nosmem, 256x256x256_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementA = cutlass::float_e2m1_t; + using ElementB = cutlass::float_e2m1_t; + using ElementC = float; + using ElementD = float; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementSF = cutlass::float_ue8m0_t; + using MmaTypePairA = cute::tuple; + using MmaTypePairB = cute::tuple; + + using MmaTileShape = cute::Shape<_256,_128,_256>; + using ClusterShape = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, 4, + ElementD, LayoutC *, 4, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, + MmaTypePairA, LayoutA *, 32, + MmaTypePairB, LayoutB *, 32, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + auto pass = test::gemm::device::TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + TEST(SM100Only_Device_Gemm_e2m1t_e2m1n_f32n_tensorop_2sm_f32_group, 256x256x256_2x2x1) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx.cu new file mode 100644 index 00000000..19df6a52 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx.cu @@ -0,0 +1,209 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e4m3n_e4m3t_e4m3t_tensorop_1sm_f32_align4, 64x128x64_1x1x1) { + using MmaTileShape = Shape<_64,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3n_e4m3t_e4m3t_tensorop_1sm_f32_align4, 128x128x64_1x1x1) { + using MmaTileShape = Shape<_128,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align4, 64x128x128_1x1x1) { + using MmaTileShape = Shape<_64,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_1sm_f32_align4, 128x128x128_1x1x1) { + using MmaTileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestSmall(1.0, 0.5); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx_streamK.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx_streamK.cu new file mode 100644 index 00000000..9d254073 --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_alignx_streamK.cu @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e4m3n_e4m3t_e4m3t_tensorop_1sm_f32_align4_StreamK, 64x128x64_1x1x1) { + using MmaTileShape = Shape<_64,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestSmall(1.0, 0.5, CheckEquality::RELATIVE, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, {256 + 4}); + EXPECT_TRUE(pass); +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu index 9ce8817e..0e4c4cda 100644 --- a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu @@ -303,6 +303,7 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256 } + TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256x128x128_2x1x1_64x64x64_scale) { bool passed = groupwise_test( @@ -317,4 +318,5 @@ TEST(SM100_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_2sm_f32_align16_blockwise, 256 } + #endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype_alignx.cu b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype_alignx.cu new file mode 100644 index 00000000..fe3a313c --- /dev/null +++ b/test/unit/gemm/device/sm100_gemm_f8_f8_f8_tensor_op_f32_runtime_datatype_alignx.cu @@ -0,0 +1,213 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype_align8, 64x128x128_1x1x1) { + using MmaTileShape = Shape<_64,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 8, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 8, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 8, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 8, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype_align8, 128x128x128_1x1x1) { + using MmaTileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 8, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 8, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 8, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 8, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype_align4, 64x128x128_1x1x1) { + using MmaTileShape = Shape<_64,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 4, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +TEST(SM100_Device_Gemm_e5m2t_e4m3n_e4m3t_tensorop_1sm_f32_runtime_datatype_align4, 128x128x128_1x1x1) { + using MmaTileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_1,_1,_1>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::float_e4m3_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::RowMajor, 4, + cutlass::type_erased_dynamic_float8_t, cutlass::layout::ColumnMajor, 4, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = TestRuntimeDataTypeSmall(cute::UMMA::MXF8F6F4Format::E5M2, cute::UMMA::MXF8F6F4Format::E4M3); + EXPECT_TRUE(pass); + +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt b/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt index fbe15b5e..00e54c9a 100644 --- a/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt +++ b/test/unit/gemm/device/sm100_tensorop_gemm/CMakeLists.txt @@ -44,6 +44,7 @@ cutlass_test_unit_gemm_device_add_executable( f16_f16_void_f32.cu f16_f16_f16_f16_fusion.cu + f16_f16_void_f32_narrow_mma_n.cu ) cutlass_test_unit_gemm_device_add_executable( @@ -54,6 +55,7 @@ cutlass_test_unit_gemm_device_add_executable( f8_f8_void_f32.cu f8_f8_f16_f8_fusion.cu + f8_f8_void_bf16_narrow_mma_n.cu ) cutlass_test_unit_gemm_device_add_executable( diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu new file mode 100644 index 00000000..f79837c2 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f16_f16_void_f32_narrow_mma_n.cu @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// FP16T x FP16N -> FP32 with 64x8x16 MMA-1CTA +TEST(SM100Only_Device_Gemm_f16t_f16n_void_f32n_tensor_op_f32, 64x8x64_2x2x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_8,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16N x FP16T -> FP32 with 64x8x16 MMA-1CTA +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 64x8x64_2x2x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_8,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16N x FP16T -> FP32 with 128x8x16 MMA-1CTA +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 128x8x64_2x2x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_8,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16T x FP16N -> FP32 with 128x16x16 MMA-2CTA +TEST(SM100Only_Device_Gemm_f16t_f16n_void_f32n_tensor_op_f32, 128x16x64_2x2x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_16,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16N x FP16T -> FP32 with 128x16x16 MMA-2CTA +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 128x16x64_2x2x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_16,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16T x FP16N -> FP32 with 256x16x16 MMA-2CTA +TEST(SM100Only_Device_Gemm_f16t_f16n_void_f32n_tensor_op_f32, 256x16x64_2x2x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_16,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP16N x FP16T -> FP32 with 256x16x16 MMA-2CTA +TEST(SM100Only_Device_Gemm_f16n_f16t_void_f32n_tensor_op_f32, 256x16x64_2x2x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_16,_64>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_2,_2,_1>; + + // Describe A and B tensors + using ElementA = cutlass::half_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = float; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif diff --git a/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu new file mode 100644 index 00000000..69fc4683 --- /dev/null +++ b/test/unit/gemm/device/sm100_tensorop_gemm/f8_f8_void_bf16_narrow_mma_n.cu @@ -0,0 +1,922 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../../common/cutlass_unit_test.h" + +#include "../gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +// FP8T x FP8N -> FP32 with 64x8x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_void_bf16n_tensor_op_f32, 64x8x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_8,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8N -> FP32 with 64x8x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_void_bf16n_tensor_op_f32, 64x8x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_8,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8T -> FP32 with 64x16x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_void_bf16n_tensor_op_f32, 64x16x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_64,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8T x FP8N -> FP32 with 128x8x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_void_bf16n_tensor_op_f32, 128x8x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_8,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8N -> FP32 with 128x8x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_void_bf16n_tensor_op_f32, 128x8x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_8,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8T -> FP32 with 128x16x32 MMA-1CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_void_bf16n_tensor_op_f32, 128x16x128_4x1x1_1sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized1Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8T x FP8N -> FP32 with 128x16x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_void_bf16n_tensor_op_f32, 128x16x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8N -> FP32 with 128x16x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_void_bf16n_tensor_op_f32, 128x16x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8T -> FP32 with 128x32x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_void_bf16n_tensor_op_f32, 128x32x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_128,_32,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8T x FP8N -> FP32 with 256x16x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3t_e4m3n_void_bf16n_tensor_op_f32, 256x16x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8N -> FP32 with 256x16x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3n_void_bf16n_tensor_op_f32, 256x16x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_16,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::ColumnMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +// FP8N x FP8T -> FP32 with 256x32x32 MMA-2CTA +TEST(SM100Only_Device_Gemm_e4m3n_e4m3t_void_bf16n_tensor_op_f32, 256x32x128_4x1x1_2sm) { + // Tile and cluster shapes + // Collective MMA takes tile shape of the MMA operation as input + using MmaTileShape_MNK = Shape<_256,_32,_128>; + // Cluster size for multicast + using ClusterShape_MNK = Shape<_4,_1,_1>; + + // Describe A and B tensors + using ElementA = cutlass::float_e4m3_t; + constexpr int AlignA = 128 / sizeof_bits::value; + using GmemLayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::float_e4m3_t; + constexpr int AlignB = 128 / sizeof_bits::value; + using GmemLayoutB = cutlass::layout::RowMajor; + + // Describe C and D tensors + using ElementC = void; + constexpr int AlignC = 0; + using GmemLayoutC = cutlass::layout::ColumnMajor; + using ElementD = cutlass::bfloat16_t; + constexpr int AlignD = 128 / sizeof_bits::value; + using GmemLayoutD = cutlass::layout::ColumnMajor; + + // Mma's accumulator type + using ElementAccumulator = float; + // Epilogue computation's precision type + using ElementCompute = float; + + // + // Construct CollectiveEpilogue + // + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape + ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue + ElementC, GmemLayoutC, AlignC, // C tensor description + ElementD, GmemLayoutD, AlignD, // D tensor description + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy + >::CollectiveOp; + + // + // Construct CollectiveMainloop + // + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, // Arch and Tensorop spec + ElementA, GmemLayoutA, AlignA, // A tensor elem type, layout and alignment requirement + ElementB, GmemLayoutB, AlignB, // B tensor elem type, layout and alignment requirement + ElementAccumulator, // Mma instruction accumulator type + MmaTileShape_MNK, ClusterShape_MNK, // Mma instruction tile shape, cluster shape + // Epilogue's SMEM usage that needs to be subtracted from overall SMEM capacity + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmSm100 // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + // Create Gemm Kernel using CollectiveEpilogue and CollectiveMainloop created by the builders + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + // Run tests + auto pass = test::gemm::device::TestAll(); + // Check results + EXPECT_TRUE(pass); +} + +#endif diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm.cu new file mode 100644 index 00000000..16e9fe4a --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm.cu @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32, 256x256x768_2x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32, 256x512x768_2x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, // needs 128x128 block for VS16 case + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, // needs 128x128 block for VS16 case + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x128.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x128.cu new file mode 100644 index 00000000..e4721849 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x128.cu @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32, 512x256x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16, 512x256x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x192.cu new file mode 100644 index 00000000..4db4710c --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x192.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x256.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x256.cu new file mode 100644 index 00000000..0ac4b257 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_1sm_128x256.cu @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32, 512x512x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16, 512x512x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + cutlass::half_t, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm.cu new file mode 100644 index 00000000..ad1b3b93 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm.cu @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32, 256x256x768_2x1x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32, 256x512x768_2x4x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x128.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x128.cu new file mode 100644 index 00000000..3827379c --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x128.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32, 512x256x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16, 512x256x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x192.cu new file mode 100644 index 00000000..128420d9 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x192.cu @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x256.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x256.cu new file mode 100644 index 00000000..bbabad4c --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_2sm_256x256.cu @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32, 512x512x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16, 512x512x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cute::Shape, // 128x64 and 128x128 are both workabled for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x128.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x128.cu new file mode 100644 index 00000000..9da35a9a --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x128.cu @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_group, 512x256x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32_group, 256x512x768_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_group, 256x128x768_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x192.cu new file mode 100644 index 00000000..42399986 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_1sm_128x192.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32_group, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_group, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x192.cu new file mode 100644 index 00000000..a7a47d90 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x192.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32_group, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16_group, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x256.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x256.cu new file mode 100644 index 00000000..4b7cc9c5 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_group_2sm_256x256.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32_group, 512x512x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16_group, 512x512x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC *, 4, + float, LayoutC *, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA *, 32, + cute::tuple, LayoutB *, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::GroupProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x128.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x128.cu new file mode 100644 index 00000000..6109e938 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x128.cu @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_ptr_array, 512x256x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32_ptr_array, 256x512x768_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_ptr_array, 256x128x768_2x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x192.cu new file mode 100644 index 00000000..f596ad33 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_1sm_128x192.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs32_ptr_array, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_vs16_ptr_array, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x192.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x192.cu new file mode 100644 index 00000000..3ee92be9 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x192.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32_ptr_array, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16_ptr_array, 512x384x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x256.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x256.cu new file mode 100644 index 00000000..51052599 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_ptr_array_2sm_256x256.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_ptr_array.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs32_ptr_array, 512x512x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16_ptr_array, 512x512x768_4x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, LayoutA, 32, + cute::tuple, LayoutB, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_stream_k.cu b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_stream_k.cu new file mode 100644 index 00000000..0c37a88a --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_f4_f32_tensor_op_f32_stream_k.cu @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_streamk, 256x256x768_2x1x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_streamk, 256x512x768_2x4x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_2sm_f32_vs16_streamk, 512x384x768_4x2x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + // cute::Shape, // We need 128x128 block for VS16 case for both 1SM and 2SM kernels + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::TmaWarpSpecialized2Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm103_gemm_f4_tensor_op_f32_nosmem.cu b/test/unit/gemm/device/sm103_gemm_f4_tensor_op_f32_nosmem.cu new file mode 100644 index 00000000..9be7e5f5 --- /dev/null +++ b/test/unit/gemm/device/sm103_gemm_f4_tensor_op_f32_nosmem.cu @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" +#include "cutlass/arch/mma_sm100.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) + +TEST(SM103_Device_Gemm_e2m1t_e2m1n_f32t_tensorop_1sm_f32_nosmem, 256x128x768_2x1x1) { + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::Shape>, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, cutlass::layout::RowMajor, 4, + float, cutlass::layout::RowMajor, 4, + cutlass::epilogue::NoSmemWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm103, cutlass::arch::OpClassBlockScaledTensorOp, + cute::tuple, cutlass::layout::RowMajor, 32, + cute::tuple, cutlass::layout::ColumnMajor, 32, + float, + cute::Shape>, + cute::Shape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103 + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using namespace test::gemm::device; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + auto pass = test::gemm::device::TestSmall(1.0, 0.0); + EXPECT_TRUE(pass); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM103_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu new file mode 100644 index 00000000..5cd5e503 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_f32_blockwise.cu @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/device_memory.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +template +bool groupwise_test( + Int, Int, Int, + LayoutA, LayoutB, LayoutCD, + MmaTileShape, ClusterShape) { + + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutCD, 16, + cutlass::float_e4m3_t, LayoutCD, 16, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, cute::tuple, 16, + cutlass::float_e4m3_t, cute::tuple, 16, + float, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + // Strides just iterate over scalars and have no zeros + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + int alignment_M = max(max((is_same_v ? 16 : 1) , + (SFAMajor == cute::GMMA::Major::MN ? CollectiveMainloop::AlignmentSFA : 1)), + (is_same_v ? 16 : 1)); + + int alignment_N = max(max((is_same_v ? 16 : 1) , + (SFBMajor == cute::GMMA::Major::MN ? CollectiveMainloop::AlignmentSFB : 1)), + (is_same_v ? 16 : 1)); + + int alignment_K = max(max((is_same_v ? 16 : 1) , + (SFAMajor == cute::GMMA::Major::K ? CollectiveMainloop::AlignmentSFA : 1)), + max((is_same_v ? 16 : 1) , + (SFBMajor == cute::GMMA::Major::K ? CollectiveMainloop::AlignmentSFB : 1))); + + alignment_K = (alignment_K / size<2>(MmaTileShape{}) + 1) * size<2>(MmaTileShape{}); + + int M = 1024 + alignment_M; + int N = 1024 + alignment_N; + int K = 512 + alignment_K; + EXPECT_TRUE(M % alignment_M == 0); + EXPECT_TRUE(N % alignment_N == 0); + EXPECT_TRUE(K % alignment_K == 0); + EXPECT_TRUE(K % size<2>(MmaTileShape{}) == 0); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, 1)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, 1)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, 1)); + + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + thrust::universal_vector tensor_A(M * K); + thrust::universal_vector tensor_SFA(cute::size(cute::filter_zeros(layout_SFA))); + thrust::universal_vector tensor_B(N * K); + thrust::universal_vector tensor_SFB(cute::size(cute::filter_zeros(layout_SFB))); + thrust::universal_vector tensor_C(M * N); + thrust::universal_vector tensor_D(M * N); + thrust::universal_vector tensor_ref_D(M * N); + + thrust::random::default_random_engine engine(2025); + thrust::random::uniform_int_distribution dist(-2, 2); + + std::generate(tensor_A.begin(), tensor_A.end(), [&] () { + return static_cast(dist(engine)); + }); + std::generate(tensor_SFA.begin(), tensor_SFA.end(), [&] () { + return static_cast(dist(engine)); + }); + std::generate(tensor_B.begin(), tensor_B.end(), [&] () { + return static_cast(dist(engine)); + }); + std::generate(tensor_SFB.begin(), tensor_SFB.end(), [&] () { + return static_cast(dist(engine)); + }); + std::generate(tensor_C.begin(), tensor_C.end(), [&] () { + return static_cast(dist(engine)); + }); + + typename Gemm::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + {thrust::raw_pointer_cast(tensor_A.data()), stride_A, + thrust::raw_pointer_cast(tensor_B.data()), stride_B, + thrust::raw_pointer_cast(tensor_SFA.data()), layout_SFA, + thrust::raw_pointer_cast(tensor_SFB.data()), layout_SFB}, + { + {}, // epilogue.thread + thrust::raw_pointer_cast(tensor_C.data()), stride_C, + thrust::raw_pointer_cast(tensor_D.data()), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = 1.0f; + fusion_args.beta = 1.0f; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + Gemm gemm; + + EXPECT_TRUE(gemm.can_implement(arguments) == cutlass::Status::kSuccess); + EXPECT_TRUE(gemm.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess); + EXPECT_TRUE(gemm.run() == cutlass::Status::kSuccess); + EXPECT_TRUE(cudaDeviceSynchronize() == cudaSuccess); + + auto A = cute::make_tensor(thrust::raw_pointer_cast(tensor_A.data()), + cute::make_layout(cute::make_shape(M, K, 1), stride_A)); + auto B = cute::make_tensor(thrust::raw_pointer_cast(tensor_B.data()), + cute::make_layout(cute::make_shape(N, K, 1), stride_B)); + auto C = cute::make_tensor(thrust::raw_pointer_cast(tensor_C.data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_C)); + auto D = cute::make_tensor(thrust::raw_pointer_cast(tensor_ref_D.data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_D)); + auto SFA = cute::make_tensor(thrust::raw_pointer_cast(tensor_SFA.data()), layout_SFA); + auto SFB = cute::make_tensor(thrust::raw_pointer_cast(tensor_SFB.data()), layout_SFB); + + cutlass::reference::host::GettBlockScalingMainloopParams< + float, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + float, + float, + float, + float, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = 1.0f; + epilogue_params.beta = 1.0f; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + bool equal = true; + for (size_t i = 0; i < tensor_ref_D.size(); ++i) { + equal &= (tensor_ref_D[i] == tensor_D[i]); + } + return equal; +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_f32_align16_blockwise, 128x128x128_1x1x1_1x128x128_scale) { + + bool passed = groupwise_test( + Int<1>{}, Int<128>{}, Int<128>{}, + cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, + cutlass::layout::RowMajor{}, + Shape<_128,_128,_128>{}, + Shape<_1,_1,_1>{}); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_f32_align16_blockwise, 128x128x128_1x1x1_1x1x128_scale) { + + bool passed = groupwise_test( + Int<1>{}, Int<128>{}, Int<128>{}, + cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, + cutlass::layout::RowMajor{}, + Shape<_256,_128,_128>{}, + Shape<_2,_1,_1>{}); + + EXPECT_TRUE(passed); + +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3t_tensorop_f32_align16_blockwise, 128x128x128_1x1x1_1x128x128_k_maj_k_maj_scale) { + + bool passed = groupwise_test( + Int<1>{}, Int<128>{}, Int<128>{}, + cutlass::layout::RowMajor{}, cutlass::layout::ColumnMajor{}, + cutlass::layout::RowMajor{}, + Shape<_128,_128,_128>{}, + Shape<_1,_1,_1>{}); + + EXPECT_TRUE(passed); + +} + +#endif // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 57c584d5..90dfffa4 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -283,12 +283,34 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) set(CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/generated_kernels.txt CACHE STRING "Generated kernel listing file") +set(CUTLASS_LIBRARY_HEURISTICS_TESTLIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/heuristics.csv CACHE STRING "Generated heuristics configs CSV") +set(CUTLASS_LIBRARY_HEURISTICS_GPU "" CACHE STRING "GPU to use for GEMM heuristics") +set(CUTLASS_LIBRARY_HEURISTICS_RESTRICT_KERNELS OFF CACHE BOOL + "Restrict heuristics kernels to only the default set of kernels emitted by generator.py") + + +if(CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE) + set(HEURISTICS_ARGS + --heuristics-problems-file "${CUTLASS_LIBRARY_HEURISTICS_PROBLEMS_FILE}" + --heuristics-testlist-file "${CUTLASS_LIBRARY_HEURISTICS_TESTLIST_FILE}" + --heuristics-configs-per-problem "${CUTLASS_LIBRARY_HEURISTICS_CONFIGS_PER_PROBLEM}" + ) + + if(CUTLASS_LIBRARY_HEURISTICS_RESTRICT_KERNELS) + list(APPEND HEURISTICS_ARGS --heuristics-restrict-kernels) + endif() + + if(CUTLASS_LIBRARY_HEURISTICS_GPU) + list(APPEND HEURISTICS_ARGS --heuristics-gpu "${CUTLASS_LIBRARY_HEURISTICS_GPU}") + endif() +endif() + # --log-level is set to DEBUG to enable printing information about which kernels were excluded # from generation in /python/cutlass_library/manifest.py. To avoid having this information appear # in ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log, set this parameter to INFO execute_process( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../python/cutlass_library - COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR} + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR}:${CUTLASS_NVMMH_PY_DIR} ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py --operations "${CUTLASS_LIBRARY_OPERATIONS}" --build-dir ${PROJECT_BINARY_DIR} @@ -304,6 +326,7 @@ execute_process( --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" --log-level INFO --disable-cutlass-package-imports + ${HEURISTICS_ARGS} RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h index 34939b33..df241e3c 100644 --- a/tools/library/include/cutlass/library/arch_mappings.h +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -127,12 +127,25 @@ template struct ArchMap { template <> struct ArchMap { static int const kMin = 100; - static int const kMax = 101; + #if (__CUDACC_VER_MAJOR__ >= 13) + static int const kMax = 110; + #else + static int const kMax = 103; + #endif // __CUDACC_VER_MAJOR__ >= 13 +}; + +template struct ArchMap { + static int const kMin = 103; + static int const kMax = 1024; +}; +template <> struct ArchMap { + static int const kMin = 103; + static int const kMax = 103; }; template struct ArchMap { static int const kMin = 120; - static int const kMax = 120; + static int const kMax = 121; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index f2b10dac..2e343589 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -104,7 +104,11 @@ if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 A set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) else() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) - set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) + if (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED)) + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) + else() + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=BlockScaledGemm --mode=trace --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) + endif() endif() set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true) diff --git a/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h index 368fabb2..5d500d91 100644 --- a/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -180,6 +180,8 @@ public: /// Buffer used for the cutlass reduction operations' host workspace std::vector reduction_host_workspace; + + cudaStream_t stream; }; protected: diff --git a/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h index c6a1aa35..c110de27 100644 --- a/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h @@ -100,6 +100,13 @@ public: cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; int swizzle_size{1}; + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; @@ -122,6 +129,14 @@ public: ProblemSpace const &problem_space, ProblemSpace::Problem const &problem); + int64_t bytes_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + /// Total number of bytes loaded int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; diff --git a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h index d41b1ba5..62d47990 100644 --- a/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h +++ b/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -176,6 +176,8 @@ public: std::vector host_workspace; DeviceAllocation device_workspace; + + cudaStream_t stream; }; private: diff --git a/tools/profiler/include/cutlass/profiler/options.h b/tools/profiler/include/cutlass/profiler/options.h index 0800e440..1a957b36 100644 --- a/tools/profiler/include/cutlass/profiler/options.h +++ b/tools/profiler/include/cutlass/profiler/options.h @@ -346,6 +346,10 @@ public: /// Vector of operation name substrings std::vector operation_names; + /// Map of problems to run for each operation + /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] + std::unordered_map> operation_problems; + /// Vector of operation name substrings std::vector excluded_operation_names; diff --git a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu index 73239dce..84faffbe 100644 --- a/tools/profiler/src/block_scaled_gemm_operation_profiler.cu +++ b/tools/profiler/src/block_scaled_gemm_operation_profiler.cu @@ -176,7 +176,7 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { // default value - this->cluster_m = 1; + this->cluster_m = std::string(operation_desc.name).find("_2sm") != std::string::npos ? 2 : 1; } if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { @@ -191,17 +191,17 @@ Status BlockScaledGemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { // default value - this->cluster_m_fallback = 0; + this->cluster_m_fallback = (this->cluster_m % 2 == 0) ? 2 : 1; } if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { // default value - this->cluster_n_fallback = 0; + this->cluster_n_fallback = 1; } if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { // default value - this->cluster_k_fallback = 0; + this->cluster_k_fallback = 1; } if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) { @@ -540,6 +540,8 @@ Status BlockScaledGemmOperationProfiler::initialize_configuration( gemm_workspace_.arguments.use_pdl = problem_.use_pdl; + cudaStreamCreateWithFlags(&gemm_workspace_.stream, cudaStreamNonBlocking); + // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { if (!initialize_reduction_configuration_(operation, problem)) { @@ -643,6 +645,7 @@ void BlockScaledGemmOperationProfiler::initialize_result_( result.bytes = problem_.bytes(operation_desc); result.flops = problem_.flops(operation_desc); result.runtime = 0; + result.runtime_vector.resize(options.device.devices.size(), 0); } @@ -1578,7 +1581,7 @@ Status BlockScaledGemmOperationProfiler::profile_cutlass_( } } - auto func = [&](cudaStream_t, int iteration) { + auto func = [&](cudaStream_t stream, int iteration) { // Iterate over copies of the problem in memory int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; @@ -1599,7 +1602,7 @@ Status BlockScaledGemmOperationProfiler::profile_cutlass_( arguments, host_workspace, device_workspace, - nullptr); + stream); if (status != Status::kSuccess) { return status; @@ -1611,7 +1614,7 @@ Status BlockScaledGemmOperationProfiler::profile_cutlass_( &gemm_workspace_.reduction_arguments, gemm_workspace_.reduction_host_workspace.data(), nullptr, - nullptr); + stream); if (status != Status::kSuccess) { return status; @@ -1621,7 +1624,7 @@ Status BlockScaledGemmOperationProfiler::profile_cutlass_( return status; }; - return profile_kernel_(result, options, func); + return profile_kernel_(result, options, func, gemm_workspace_.stream); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/blockwise_gemm_operation_profiler.cu b/tools/profiler/src/blockwise_gemm_operation_profiler.cu index 4a8e2543..2e20dee3 100644 --- a/tools/profiler/src/blockwise_gemm_operation_profiler.cu +++ b/tools/profiler/src/blockwise_gemm_operation_profiler.cu @@ -116,21 +116,21 @@ void BlockwiseGemmOperationProfiler::print_examples(std::ostream &out) const { << " $ cutlass_profiler --operation=blockwise_gemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" << "For column major, use column, col, or n. For row major use, row or t:\n" - << " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n" + << " $ cutlass_profiler --operation=blockwise_gemm --A=f16:column --B=*:row\n\n" << "Profile a particular problem size with split K and parallel reduction:\n" - << " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" + << " $ cutlass_profiler --operation=blockwise_gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n" << "Using various input value distribution:\n" - << " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n" - << " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n" - << " $ cutlass_profiler --operation=Gemm --dist=sequential,start:0,delta:1\n\n" + << " $ cutlass_profiler --operation=blockwise_gemm --dist=uniform,min:0,max:3\n" + << " $ cutlass_profiler --operation=blockwise_gemm --dist=gaussian,mean:0,stddev:3\n" + << " $ cutlass_profiler --operation=blockwise_gemm --dist=sequential,start:0,delta:1\n\n" << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" - << " $ cutlass_profiler --operation=Gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" + << " $ cutlass_profiler --operation=blockwise_gemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" - << " $ cutlass_profiler --operation=Gemm \\ \n" + << " $ cutlass_profiler --operation=blockwise_gemm \\ \n" << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" @@ -194,7 +194,7 @@ Status BlockwiseGemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { // default value - this->cluster_m = 1; + this->cluster_m = std::string(operation_desc.name).find("_2sm") != std::string::npos ? 2 : 1; } if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { @@ -209,17 +209,17 @@ Status BlockwiseGemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { // default value - this->cluster_m_fallback = 0; + this->cluster_m_fallback = (this->cluster_m % 2 == 0) ? 2 : 1; } if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { // default value - this->cluster_n_fallback = 0; + this->cluster_n_fallback = 1; } if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { // default value - this->cluster_k_fallback = 0; + this->cluster_k_fallback = 1; } @@ -331,33 +331,80 @@ Status BlockwiseGemmOperationProfiler::GemmProblem::parse( this->ldc = DeviceAllocation::get_packed_layout( operation_desc.C.layout, {int(this->m), int(this->n)}).front(); + // instantiation + int num_sizes = 8; + this->problem_sizes.resize(num_sizes); + this->leading_dims.resize(num_sizes, {0, 0, 0}); + + int m0 = 1024; + int n0 = 1024; + int k0 = 1024; + for (int i = 0; i < num_sizes; i++) { + auto m = m0 * (i + 1); + auto n = n0 * (i + 1); + auto k = k0 * (i + 1); + this->problem_sizes[i] = {m, n, k}; + this->leading_dims[i] = { + DeviceAllocation::get_packed_layout(operation_desc.A.layout, {int(m), int(k)}).front(), + DeviceAllocation::get_packed_layout(operation_desc.B.layout, {int(k), int(n)}).front(), + DeviceAllocation::get_packed_layout(operation_desc.C.layout, {int(m), int(n)}).front() + }; + + } + + this->swizzle_sizes = {1, 2, 4, 8}; + + this->preferred_clusters = { + {1, 1, 1}, {2, 1, 1}, {2, 2, 1}, {4, 1, 1}, {4, 2, 1}, {4, 4, 1}, {8, 2, 1} + }; + + this->fallback_clusters = { + {1, 1, 1}, {2, 1, 1}, {2, 2, 1} + }; + + this->raster_orders = { + cutlass::library::RasterOrder::kAlongN, + cutlass::library::RasterOrder::kAlongM + }; + return Status::kSuccess; } -/// Total number of bytes loaded -int64_t BlockwiseGemmOperationProfiler::GemmProblem::bytes(library::BlockwiseGemmDescription const &operation_desc) const { - // Input bytes read and Output bytes written for the gemm problem +int64_t BlockwiseGemmOperationProfiler::GemmProblem::bytes_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const { + int64_t bytes = - int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * k + - int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + - int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + int64_t(library::sizeof_bits(operation_desc.A.element) * problem_shape.m() / 8) * problem_shape.k() + + int64_t(library::sizeof_bits(operation_desc.B.element) * problem_shape.n() / 8) * problem_shape.k() + + int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n() + + int64_t(library::sizeof_bits(operation_desc.SFA.element) * problem_shape.m() / operation_desc.SFMVecSize / 8) * problem_shape.k() / operation_desc.SFKVecSize + + int64_t(library::sizeof_bits(operation_desc.SFB.element) * problem_shape.n() / operation_desc.SFNVecSize / 8) * problem_shape.k() / operation_desc.SFKVecSize; // Set is_beta_zero true if beta is zero bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); // Output bytes read for the gemm problem for non-zero beta values if (!is_beta_zero) { - bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_shape.m() / 8) * problem_shape.n(); } bytes *= batch_count; return bytes; + +} + +/// Total number of bytes loaded +int64_t BlockwiseGemmOperationProfiler::GemmProblem::bytes(library::BlockwiseGemmDescription const &operation_desc) const { + return bytes_with_problem_shape(operation_desc, {int(m), int(n), int(k)}); } /// Total number of flops computed -int64_t BlockwiseGemmOperationProfiler::GemmProblem::flops(library::BlockwiseGemmDescription const &operation_desc) const { - int64_t flops_ = (int64_t(m) * n * k + m * n) * 2 * batch_count; +int64_t BlockwiseGemmOperationProfiler::GemmProblem::flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const { + int64_t flops_ = (int64_t(problem_shape.m()) * problem_shape.n() * problem_shape.k() + problem_shape.m() * problem_shape.n()) * 2 * batch_count; // complex-valued support switch (operation_desc.tile_description.math_instruction.math_operation) { @@ -379,6 +426,10 @@ int64_t BlockwiseGemmOperationProfiler::GemmProblem::flops(library::BlockwiseGem return flops_; } +/// Total number of flops computed +int64_t BlockwiseGemmOperationProfiler::GemmProblem::flops(library::BlockwiseGemmDescription const &operation_desc) const { + return flops_with_problem_shape(operation_desc, {int(m), int(n), int(k)}); +} /// Initializes a performance result void BlockwiseGemmOperationProfiler::GemmProblem::initialize_result( @@ -1185,42 +1236,249 @@ bool BlockwiseGemmOperationProfiler::profile( if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { - // Initialize structure containing GEMM arguments - gemm_workspace_.arguments.A = gemm_workspace_.A->data(); - gemm_workspace_.arguments.B = gemm_workspace_.B->data(); - gemm_workspace_.arguments.SFA = gemm_workspace_.SFA->data(); - gemm_workspace_.arguments.SFB = gemm_workspace_.SFB->data(); - gemm_workspace_.arguments.C = gemm_workspace_.C->data(); - gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); - gemm_workspace_.arguments.alpha = problem_.alpha.data(); - gemm_workspace_.arguments.beta = problem_.beta.data(); - gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + // Case when we either screen the best performance number of kernels with or without a fixed problem shape fed in. + if (options.profiling.enable_kernel_performance_search || options.profiling.enable_best_kernel_for_fixed_shape) { + library::BlockwiseGemmDescription const &operation_desc = + static_cast(operation->description()); - if (problem_.split_k_mode == library::SplitKMode::kParallel) { - gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); - gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); - gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + auto cluster_shape = operation_desc.tile_description.cluster_shape; + bool is_dynamic_cluster_enabled = cluster_shape.m() == 0 || cluster_shape.n() == 0 || cluster_shape.k() == 0; - gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); - gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); - gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); - gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); - gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); - gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + // Helper function wrapping up performance test with flexible parameters. + auto initialize_and_profile = [&]( + PerformanceResult const &result, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size) -> std::optional { + + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.SFA = gemm_workspace_.SFA->data(); + gemm_workspace_.arguments.SFB = gemm_workspace_.SFB->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); + gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); + gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); + gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } + + gemm_workspace_.arguments.problem_size.m() = problem_shape.m(); + gemm_workspace_.arguments.problem_size.n() = problem_shape.n(); + gemm_workspace_.arguments.problem_size.k() = problem_shape.k(); + + gemm_workspace_.arguments.lda = leading_dim[0]; + gemm_workspace_.arguments.ldb = leading_dim[1]; + gemm_workspace_.arguments.ldc = leading_dim[2]; + + gemm_workspace_.arguments.swizzle_size = swizzle_size; + gemm_workspace_.arguments.raster_order = raster_order; + + if (is_dynamic_cluster_enabled) { + gemm_workspace_.arguments.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + gemm_workspace_.configuration.cluster_shape = {int(preferred_cluster[0]), int(preferred_cluster[1]), int(preferred_cluster[2])}; + gemm_workspace_.configuration.cluster_shape_fallback = {int(fallback_cluster[0]), int(fallback_cluster[1]), int(fallback_cluster[2])}; + } + + gemm_workspace_.configuration.problem_size.m() = problem_shape.m(); + gemm_workspace_.configuration.problem_size.n() = problem_shape.n(); + gemm_workspace_.configuration.problem_size.k() = problem_shape.k(); + + gemm_workspace_.configuration.lda = leading_dim[0]; + gemm_workspace_.configuration.ldb = leading_dim[1]; + gemm_workspace_.configuration.ldc = leading_dim[2]; + + const auto can_implement = operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); + if (can_implement != Status::kSuccess) { + return std::nullopt; // Return nullopt to indicate failure + } + library::Operation const* underlying_operation = operation; + uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration); + gemm_workspace_.host_workspace.resize(workspace_size, 0); + + workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, + &gemm_workspace_.arguments); + + gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + + Status status = underlying_operation->initialize( + &gemm_workspace_.configuration, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data(), + nullptr); + + if (status != Status::kSuccess) { + return std::nullopt; // Return nullopt to indicate failure + } + + + PerformanceResult curr_result(result); + curr_result.bytes = problem_.bytes_with_problem_shape(operation_desc, problem_shape); + curr_result.flops = problem_.flops_with_problem_shape(operation_desc, problem_shape); + + set_argument(curr_result, "m", problem_space, problem_shape.m()); + set_argument(curr_result, "n", problem_space, problem_shape.n()); + set_argument(curr_result, "k", problem_space, problem_shape.k()); + + set_argument(curr_result, "raster_order", problem_space, library::to_string(raster_order)); + set_argument(curr_result, "swizzle_size", problem_space, swizzle_size); + + if (is_dynamic_cluster_enabled) { + set_argument(curr_result, "cluster_m", problem_space, preferred_cluster[0]); + set_argument(curr_result, "cluster_n", problem_space, preferred_cluster[1]); + set_argument(curr_result, "cluster_k", problem_space, preferred_cluster[2]); + set_argument(curr_result, "cluster_m_fallback", problem_space, fallback_cluster[0]); + set_argument(curr_result, "cluster_n_fallback", problem_space, fallback_cluster[1]); + set_argument(curr_result, "cluster_k_fallback", problem_space, fallback_cluster[2]); + } + + + curr_result.status = profile_cutlass_( + curr_result, + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data() + ); + + return curr_result; + }; + + // Helper function to test validity of fallback cluster shapes and preferred cluster shapes. + auto is_valid_dynamic_cluster_shape = [](const std::array& preferred_cluster, const std::array& fallback_cluster) { + for (size_t i = 0; i < 3; ++i) { + if (preferred_cluster[i] % fallback_cluster[i] != 0) { + return false; + } + } + return true; + }; + + // Helper function to select the best performance number among a list. + auto select_best_candidate = [&](std::vector &candidates) { + assert(!candidates.empty() && "Candidates vector should not be empty"); + auto best_iter = std::max_element( + candidates.begin(), candidates.end(), + [](PerformanceResult const &a, PerformanceResult const &b) { + return a.gflops_per_sec() < b.gflops_per_sec(); + } + ); + assert(best_iter != candidates.end() && "No candidate found despite non-empty candidates vector"); + results_.push_back(std::move(*best_iter)); + }; + + std::vector candidates; + PerformanceResult result_base = results_.back(); + results_.pop_back(); + + std::vector> preferred_clusters; + std::vector> fallback_clusters; + + // Only loop over built-in cluster shape lists for dynamic cluster kernels + // and for kernels that can leverage the dynamic cluster feature. + if (is_dynamic_cluster_enabled) { + preferred_clusters = this->problem_.preferred_clusters; + fallback_clusters = this->problem_.fallback_clusters; + } + else { + preferred_clusters = {{int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}}; + fallback_clusters = {{int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}}; + } + + for (auto preferred_cluster : preferred_clusters) { + for (auto fallback_cluster : fallback_clusters) { + if (is_dynamic_cluster_enabled && !is_valid_dynamic_cluster_shape(preferred_cluster, fallback_cluster)) { + continue; + } + for (auto swizzle_size : this->problem_.swizzle_sizes) { + for (auto raster_order : this->problem_.raster_orders) { + // With the fixed shape option turned on, only a specific problem shape is tested. + if (options.profiling.enable_best_kernel_for_fixed_shape) { + this->problem_.problem_sizes = {{int(this->problem_.m), int(this->problem_.n), int(this->problem_.k)}}; + this->problem_.leading_dims = {{this->problem_.lda, this->problem_.ldb, this->problem_.ldc}}; + } + + for (int i = 0; i < int(this->problem_.problem_sizes.size()); i++) { + gemm::GemmCoord problem_shape = problem_.problem_sizes[i]; + std::array leading_dim = problem_.leading_dims[i]; + auto result_opt = initialize_and_profile(result_base, problem_shape, leading_dim, preferred_cluster, fallback_cluster, raster_order, swizzle_size); + + if (result_opt) { // Only add valid results + candidates.push_back(*result_opt); + } + + } + + }// for raster_order + }// for swizzle_size + }// for fallback_cluster + }// for swizzle_size + + if (candidates.empty()) { + return false; + } + + select_best_candidate(candidates); } + else { + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.SFA = gemm_workspace_.SFA->data(); + gemm_workspace_.arguments.SFB = gemm_workspace_.SFB->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); - results_.back().status = profile_cutlass_( - results_.back(), - options, - operation, - &gemm_workspace_.arguments, - gemm_workspace_.host_workspace.data(), - gemm_workspace_.device_workspace.data() - ); + if (problem_.split_k_mode == library::SplitKMode::kParallel) { + gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data(); + gemm_workspace_.arguments.alpha = problem_.alpha_one.data(); + gemm_workspace_.arguments.beta = problem_.beta_zero.data(); + + gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data(); + gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data(); + gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data(); + gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data(); + gemm_workspace_.reduction_arguments.beta = problem_.beta.data(); + gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost; + } + + results_.back().status = profile_cutlass_( + results_.back(), + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data() + ); + } } return true; } diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index f932772d..48e451e7 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -176,7 +176,7 @@ Status GemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m, "cluster_m", problem_space, problem)) { // default value - this->cluster_m = 1; + this->cluster_m = std::string(operation_desc.name).find("_2sm") != std::string::npos ? 2 : 1; } if (!arg_as_int(this->cluster_n, "cluster_n", problem_space, problem)) { @@ -191,17 +191,17 @@ Status GemmOperationProfiler::GemmProblem::parse( if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { // default value - this->cluster_m_fallback = 0; + this->cluster_m_fallback = (this->cluster_m % 2 == 0) ? 2 : 1; } if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { // default value - this->cluster_n_fallback = 0; + this->cluster_n_fallback = 1; } if (!arg_as_int(this->cluster_k_fallback, "cluster_k_fallback", problem_space, problem)) { // default value - this->cluster_k_fallback = 0; + this->cluster_k_fallback = 1; } if (!arg_as_bool(this->use_pdl, "use_pdl", problem_space, problem)) { diff --git a/tools/profiler/src/grouped_gemm_operation_profiler.cu b/tools/profiler/src/grouped_gemm_operation_profiler.cu index 4ef9f564..2297fab1 100644 --- a/tools/profiler/src/grouped_gemm_operation_profiler.cu +++ b/tools/profiler/src/grouped_gemm_operation_profiler.cu @@ -283,7 +283,7 @@ Status GroupedGemmOperationProfiler::GroupedGemmProblem::parse( if (!arg_as_int(this->cluster_m_fallback, "cluster_m_fallback", problem_space, problem)) { // default value - this->cluster_m_fallback = std::string(operation_desc.gemm.name).find("_2sm") != std::string::npos ? 2 : 1; + this->cluster_m_fallback = (this->cluster_m % 2 == 0) ? 2 : 1; } if (!arg_as_int(this->cluster_n_fallback, "cluster_n_fallback", problem_space, problem)) { @@ -408,6 +408,11 @@ int64_t GroupedGemmOperationProfiler::GroupedGemmProblem::bytes( for (size_t group_idx = 0, num_groups = problem_sizes.size(); group_idx < num_groups; group_idx++) { + // If M = 0 or N = 0, no tiles are scheduled and no bytes are loaded for the group + if (m(group_idx) * n(group_idx) == 0) { + continue; + } + bytes += int64_t(library::sizeof_bits(operation_desc.gemm.A.element) * m(group_idx) / 8) * k(group_idx) + int64_t(library::sizeof_bits(operation_desc.gemm.B.element) * n(group_idx) / 8) * k(group_idx) + @@ -630,6 +635,8 @@ Status GroupedGemmOperationProfiler::initialize_configuration( gemm_workspace_.arguments.use_pdl = problem_.use_pdl; + cudaStreamCreateWithFlags(&gemm_workspace_.stream, cudaStreamNonBlocking); + initialize_result_(this->model_result_, options, operation_desc, problem_space); return status; @@ -654,6 +661,7 @@ void GroupedGemmOperationProfiler::initialize_result_( result.bytes = problem_.bytes(operation_desc); result.flops = problem_.flops(operation_desc); result.runtime = 0; + result.runtime_vector.resize(options.device.devices.size(), 0); } @@ -1585,9 +1593,9 @@ Status GroupedGemmOperationProfiler::profile_cutlass_( void* host_workspace, void* device_workspace) { library::Operation const* underlying_operation = operation; - results_.back().status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments); - if (results_.back().status != Status::kSuccess) { - return results_.back().status; + result.status = underlying_operation->initialize_with_arguments(&gemm_workspace_.arguments); + if (result.status != Status::kSuccess) { + return result.status; } auto func = [&](cudaStream_t stream, int iteration) { @@ -1600,9 +1608,9 @@ Status GroupedGemmOperationProfiler::profile_cutlass_( gemm_workspace_.arguments.ptr_C = gemm_workspace_.C_ptr_array_device[problem_idx]->data(); gemm_workspace_.arguments.ptr_D = gemm_workspace_.D_ptr_array_device[problem_idx]->data(); - return underlying_operation->run(arguments, host_workspace, device_workspace); + return underlying_operation->run(arguments, host_workspace, device_workspace, stream); }; - return profile_kernel_(result, options, func); + return profile_kernel_(result, options, func, gemm_workspace_.stream); } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index 387640b2..e968aeb8 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -348,132 +348,127 @@ int OperationProfiler::profile_all( Options const &options, library::Manifest const &manifest, DeviceContext &device_context) { - ProblemSpace problem_space(arguments_, options.cmdline); + ProblemSpace cmdline_problem_space(arguments_, options.cmdline); + + bool do_testlist_run = !options.operation_problems.empty(); + + std::vector>> all_operations_and_problems; + if (do_testlist_run) { + for (const auto& [operation_name, cmd_vec] : options.operation_problems) { + for (auto& cmd_line : cmd_vec) { + all_operations_and_problems.push_back({operation_name, std::make_unique(arguments_, cmd_line)}); + } + } + } // 1. Construct performance report - PerformanceReport report(options, problem_space.argument_names(), kind_); + PerformanceReport report(options, cmdline_problem_space.argument_names(), kind_); - // 2. For each problem in problem space - ProblemSpace::Iterator problem_it = problem_space.begin(); - ProblemSpace::Iterator problem_end = problem_space.end(); - - bool continue_profiling = true; + // int retval = 0; - // For each problem in problem space - for (; continue_profiling && problem_it != problem_end; ++problem_it) { - ProblemSpace::Problem problem = problem_it.at(); - report.next_problem(); + size_t bound = (all_operations_and_problems.empty() ? 1 : all_operations_and_problems.size()); + for (size_t i = 0; i < bound; i++) { - // For each operation in manifest - int matched_operation_count = 0; - int profiled_operation_count = 0; - for (auto const& operation_ptr : manifest) { + // New problem space for each operation if we are running a testlist + ProblemSpace& problem_space = do_testlist_run ? *all_operations_and_problems[i].second : cmdline_problem_space; - library::Operation const *operation = operation_ptr.get(); + // 2. For each problem in problem space + ProblemSpace::Iterator problem_it = problem_space.begin(); + ProblemSpace::Iterator problem_end = problem_space.end(); + + bool continue_profiling = true; + + // For each problem in problem space + for (; continue_profiling && problem_it != problem_end; ++problem_it) { + ProblemSpace::Problem problem = problem_it.at(); + report.next_problem(); + + // For each operation in manifest + int matched_operation_count = 0; + int profiled_operation_count = 0; + for (auto const& operation_ptr : manifest) { + + library::Operation const *operation = operation_ptr.get(); #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " Operation: " << typeid(*operation).name() << "\n" - << " name: " << operation->description().name << "\n" - << " kind: " << operation->description().kind << "\n" - << " provider: " << operation->description().provider << "\n"; + std::cerr << " Operation: " << typeid(*operation).name() << "\n" + << " name: " << operation->description().name << "\n" + << " kind: " << operation->description().kind << "\n" + << " provider: " << operation->description().provider << "\n"; #endif // CUTLASS_DEBUG_TRACE_LEVEL - auto min_cc = operation->description().tile_description.minimum_compute_capability; - auto max_cc = operation->description().tile_description.maximum_compute_capability; + auto min_cc = operation->description().tile_description.minimum_compute_capability; + auto max_cc = operation->description().tile_description.maximum_compute_capability; #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - std::cerr << " min_cc: " << min_cc << "\n"; - std::cerr << " max_cc: " << min_cc << "\n"; + std::cerr << " min_cc: " << min_cc << "\n"; + std::cerr << " max_cc: " << min_cc << "\n"; #endif - // Clear named allocations - device_context.free(); + // Clear named allocations + device_context.free(); #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) - if (operation->description().kind != kind_) { - std::cerr << " @ kind " << operation->description().kind - << " != kind_ " << kind_ << "\n"; - } - if (operation->description().provider != library::Provider::kCUTLASS) { - std::cerr << " @ provider " << operation->description().provider - << " != library::Provider::kCUTLASS\n"; - } - if (options.device.compute_capability(0) < min_cc) { - std::cerr << " @ compute_capability " - << options.device.compute_capability(0) - << " < min_cc " << min_cc << "\n"; - } - if (options.device.compute_capability(0) > max_cc) { - std::cerr << " @ compute_capability " - << options.device.compute_capability(0) - << " > max_cc " << max_cc << "\n"; - } + if (operation->description().kind != kind_) { + std::cerr << " @ kind " << operation->description().kind + << " != kind_ " << kind_ << "\n"; + } + if (operation->description().provider != library::Provider::kCUTLASS) { + std::cerr << " @ provider " << operation->description().provider + << " != library::Provider::kCUTLASS\n"; + } + if (options.device.compute_capability(0) < min_cc) { + std::cerr << " @ compute_capability " + << options.device.compute_capability(0) + << " < min_cc " << min_cc << "\n"; + } + if (options.device.compute_capability(0) > max_cc) { + std::cerr << " @ compute_capability " + << options.device.compute_capability(0) + << " > max_cc " << max_cc << "\n"; + } #endif - // Execute compatible cutlass operations if they satisfy the current device's compute capability - if (operation->description().kind == kind_ && - operation->description().provider == library::Provider::kCUTLASS && - options.device.compute_capability(0) >= min_cc && - options.device.compute_capability(0) <= max_cc) { + // Execute compatible cutlass operations if they satisfy the current device's compute capability + if (operation->description().kind == kind_ && + operation->description().provider == library::Provider::kCUTLASS && + options.device.compute_capability(0) >= min_cc && + options.device.compute_capability(0) <= max_cc) { - std::string operation_name(operation->description().name); - // Filter kernels by name - bool filtered_by_name = options.operation_names.empty(); - if (!filtered_by_name) { + std::string operation_name(operation->description().name); + // Filter kernels by name + bool filtered_by_name = options.operation_names.empty(); + if (!filtered_by_name) { - for (auto const & op_name : options.operation_names) { + for (auto const & op_name : options.operation_names) { + if (find_string_matches_(op_name, operation_name)) { + filtered_by_name = true; + break; + } + } + } + + for (auto const & op_name : options.excluded_operation_names) { if (find_string_matches_(op_name, operation_name)) { - filtered_by_name = true; + filtered_by_name = false; break; } } - } - for (auto const & op_name : options.excluded_operation_names) { - if (find_string_matches_(op_name, operation_name)) { + // Problems list uses exact match on operation names + if (do_testlist_run && !(all_operations_and_problems[i].first == operation_name)) { filtered_by_name = false; - break; - } - } - - if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { - continue; - } - - // we have found a kernel match, so increment the counter for match kernels - ++matched_operation_count; - - // A. Initialize configuration - Status status = this->initialize_configuration( - options, - report, - device_context, - operation, - problem_space, - problem); - - if (status == Status::kErrorInternal) { - - // If there was an internal error, consume the CUDA error and move to the next operation. - (void)cudaGetLastError(); - - report.append_result(model_result_); - continue; - } - else if (status != Status::kSuccess) { - // If the workspace could not be initialized for any other reason, continue to - // the next operation. - continue; - } - - if (continue_profiling) { - - if (options.report.print_kernel_before_running) { - std::cout << "Profiling kernel for JUnit test " << options.report.junit_output_path << ": " - << operation_name << std::endl; } - status = this->initialize_workspace( + if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { + continue; + } + + // we have found a kernel match, so increment the counter for match kernels + ++matched_operation_count; + + // A. Initialize configuration + Status status = this->initialize_configuration( options, report, device_context, @@ -486,7 +481,7 @@ int OperationProfiler::profile_all( // If there was an internal error, consume the CUDA error and move to the next operation. (void)cudaGetLastError(); - report.append_results(results_); + report.append_result(model_result_); continue; } else if (status != Status::kSuccess) { @@ -494,93 +489,123 @@ int OperationProfiler::profile_all( // the next operation. continue; } - } - // - // Profile CUTLASS if it is enabled - // + if (continue_profiling) { - // B. Verify CUTLASS - if (continue_profiling && options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + if (options.report.print_kernel_before_running) { + std::cout << "Profiling kernel for JUnit test " << options.report.junit_output_path << ": " + << operation_name << std::endl; + } - continue_profiling = this->verify_cutlass( - options, - report, - device_context, - operation, - problem_space, - problem); + status = this->initialize_workspace( + options, + report, + device_context, + operation, + problem_space, + problem); - retval |= (not continue_profiling); - } + if (status == Status::kErrorInternal) { + + // If there was an internal error, consume the CUDA error and move to the next operation. + (void)cudaGetLastError(); + + report.append_results(results_); + continue; + } + else if (status != Status::kSuccess) { + // If the workspace could not be initialized for any other reason, continue to + // the next operation. + continue; + } + } + + // + // Profile CUTLASS if it is enabled + // + + // B. Verify CUTLASS + if (continue_profiling && options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + + continue_profiling = this->verify_cutlass( + options, + report, + device_context, + operation, + problem_space, + problem); + + retval |= (not continue_profiling); + } + + if (options.execution_mode == ExecutionMode::kDryRun) { + report.append_results(results_); + results_.clear(); + continue; + } + + // + // C. Optionally save workspace + // + + if (options.verification.save_workspace == SaveWorkspace::kAlways) { + save_workspace( + device_context, + options, + operation->description(), + library::Provider::kCUTLASS); + } + + // + // D. Profile + // + + if (continue_profiling && options.profiling.enabled) { + + continue_profiling = this->profile( + options, + report, + device_context, + operation, + problem_space, + problem); + + // Count op as profiled, even it failed to profile + profiled_operation_count++; + } - if (options.execution_mode == ExecutionMode::kDryRun) { report.append_results(results_); results_.clear(); - continue; + } // if op satisfied compute capacity + + if (!continue_profiling) { + // break out of `for op in manifest` loop and move to next problem + // `for each problem in problem space` conditional check on not continue profiling + break; } + } // for op in manifest - // - // C. Optionally save workspace - // - - if (options.verification.save_workspace == SaveWorkspace::kAlways) { - save_workspace( - device_context, - options, - operation->description(), - library::Provider::kCUTLASS); - } - - // - // D. Profile - // - - if (continue_profiling && options.profiling.enabled) { - - continue_profiling = this->profile( - options, - report, - device_context, - operation, - problem_space, - problem); - - // Count op as profiled, even it failed to profile - profiled_operation_count++; - } - - report.append_results(results_); - results_.clear(); - } // if op satisfied compute capacity - - if (!continue_profiling) { - // break out of `for op in manifest` loop and move to next problem - // `for each problem in problem space` conditional check on not continue profiling - break; + // If we did not find any kernels that match our filters and error_on_no_match was set, report an error + if (options.profiling.error_on_no_match && matched_operation_count <= 0) { + #if !NDEBUG + std::cerr << "Error: No matching kernels found with kernel selection filters [--error_on_no_match]" << std::endl; + #endif + retval |= 1; + // Stop profiling on error no match + continue_profiling = false; } - } // for op in manifest - // If we did not find any kernels that match our filters and error_on_no_match was set, report an error - if (options.profiling.error_on_no_match && matched_operation_count <= 0) { - #if !NDEBUG - std::cerr << "Error: No matching kernels found with kernel selection filters [--error_on_no_match]" << std::endl; - #endif - retval |= 1; - // Stop profiling on error no match - continue_profiling = false; - } + if (options.profiling.error_if_nothing_is_profiled && options.profiling.enabled && profiled_operation_count <= 0) { + #if !NDEBUG + std::cerr << "Error: No kernels profiled found with kernel selection filters [--error_if_nothing_is_profiled]" << std::endl; + #endif + retval |= 1; + // Stop profiling on error no match + continue_profiling = false; + } - if (options.profiling.error_if_nothing_is_profiled && options.profiling.enabled && profiled_operation_count <= 0) { - #if !NDEBUG - std::cerr << "Error: No kernels profiled found with kernel selection filters [--error_if_nothing_is_profiled]" << std::endl; - #endif - retval |= 1; - // Stop profiling on error no match - continue_profiling = false; - } - - } // for each problem in problem space + } // for each problem in problem space + } return retval; } diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index d711c8f1..7fc1d288 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -550,6 +550,9 @@ void Options::Profiling::print_usage(std::ostream &out) const { << " --profiling-enabled= " << " If true, profiling is actually conducted.\n\n" + << " --enable-best-kernel-for-fixed-shape= " + << " If true, iterate through common cluster sizes, raster orders, and swizzle sizes for each kernel.\n\n" + ; } @@ -593,6 +596,9 @@ Options::Verification::Verification(cutlass::CommandLine const &cmdline) { if (enabled) { cmdline.get_cmd_line_argument("verification-required", required, false); } + else { + required = false; + } cmdline.get_cmd_line_argument("epsilon", epsilon, 0.05); @@ -847,6 +853,52 @@ Options::Options(cutlass::CommandLine const &cmdline): for (std::string line; getline(input, line);) { operation_names.push_back(line); } + } else if (cmdline.check_cmd_line_flag("testlist-file")) { + // Problems file is a CSV, where the first column is the kernel name and the rest are the problem arguments + std::string filename; + cmdline.get_cmd_line_argument("testlist-file", filename, {}); + std::ifstream input(filename); + if (!input.good()) { + throw std::runtime_error("failed to open: " + filename); + } + + std::string line; + std::vector col_names; + // Read header line + if (std::getline(input, line)) { + std::stringstream ss(line); + std::string header; + while (std::getline(ss, header, ',')) { + col_names.push_back(header); + } + } + + // Read content lines + while (std::getline(input, line)) { + std::stringstream ss(line); + std::string item; + + size_t colIdx = 0; + std::string operation_name; + + std::unordered_map arguments; + + while (std::getline(ss, item, ',')) { + if (!colIdx) { + // First column is operation name + if (operation_problems.find(item) == operation_problems.end()) { + operation_names.push_back(item); + } + operation_name = item; + } else { + if (colIdx < col_names.size()) { + arguments[col_names[colIdx]] = item; + } + } + colIdx++; + } + operation_problems[operation_name].emplace_back(arguments); + } } if (cmdline.check_cmd_line_flag("ignore-kernels")) { @@ -899,6 +951,10 @@ void Options::print_usage(std::ostream &out) const { << " --ignore-kernels= " << " Excludes kernels whose names match anything in this list.\n\n" + + << " --testlist-file= " + << " A CSV, where each row is a problem, where the first column is the kernel name and the rest are the problem arguments" << end_of_line + << " The column names should match cutlass_profiler cmd line arguments. \n\n" ; // diff --git a/tools/util/include/cutlass/util/command_line.h b/tools/util/include/cutlass/util/command_line.h index b60d868c..c95bd1cb 100644 --- a/tools/util/include/cutlass/util/command_line.h +++ b/tools/util/include/cutlass/util/command_line.h @@ -41,6 +41,7 @@ #include #include #include +#include #include @@ -89,6 +90,16 @@ struct CommandLine { } } + /** + * Constructor to represent a command line from a map of [argument] -> [value] + */ + CommandLine(std::unordered_map& arg_map) { + for (const auto& [key, value] : arg_map) { + keys.push_back(key); + values.push_back(value); + } + } + /** * Checks whether a flag "--" is present in the commandline */ diff --git a/tools/util/include/cutlass/util/cublas_wrappers.hpp b/tools/util/include/cutlass/util/cublas_wrappers.hpp index 8de1aa8e..8ace1e0a 100644 --- a/tools/util/include/cutlass/util/cublas_wrappers.hpp +++ b/tools/util/include/cutlass/util/cublas_wrappers.hpp @@ -51,7 +51,9 @@ // User could potentially define ComplexFloat/ComplexDouble instead of std:: #ifndef BLAM_COMPLEX_TYPES #define BLAM_COMPLEX_TYPES 1 -#include +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) + namespace blam { template using Complex = cuda::std::complex;