diff --git a/CHANGELOG.md b/CHANGELOG.md index 63e3e80e..fc269c8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,7 +32,9 @@ * CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. - Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes. - Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors. + - Support for mixed input GEMM kernels on Hopper in the profiler. * New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels. +* A new 3.x version of grouped GEMM added to the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details). * Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture: - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API. - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. @@ -45,12 +47,16 @@ - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu). - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu). - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128. + - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations. +* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture: + - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. + - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel). - Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) - A new [functionality documentation](./media/docs/functionality.md) specifically for 3.x API comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA tookit support etc for 3.x supported architectures. - Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture). - - Support grouped GEMM in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
+ - Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper. ## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11) - [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439). diff --git a/CMakeLists.txt b/CMakeLists.txt index b9de4f96..0b68b435 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,14 +150,14 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest") -set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.") -if (CUTLASS_USE_PACKED_TUPLE) - list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1") - message(STATUS "Make cute::tuple be the new standard-layout tuple type") -elseif() - message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout") + +if (CUTLASS_ENABLE_TESTS AND CUTLASS_ENABLE_PROFILER) + set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT ON) +else() + set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT OFF) endif() +set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS ${CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT} CACHE BOOL "Enable CUTLASS Profiler-based Unit Tests") +set(CUTLASS_ENABLE_SELF_CONTAINED_INCLUDES_CHECK ON CACHE BOOL "Enable CUTLASS check for self-contained header includes") ################################################################################ @@ -406,7 +406,7 @@ endif() # Warnings-as-error exceptions and warning suppressions for Clang builds if (CUTLASS_CLANG_HOST_COMPILE) - + set(FLAGS_TO_ADD "-Wno-error=implicit-int-conversion" "-Wno-error=pass-failed" @@ -414,13 +414,13 @@ if (CUTLASS_CLANG_HOST_COMPILE) "-Wno-sign-conversion" "-Wno-unused-parameter" ) - + foreach(FLAG ${FLAGS_TO_ADD}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}") list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}") list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}") endforeach() - + endif() if (NOT MSVC AND CUTLASS_NVCC_KEEP) @@ -486,7 +486,7 @@ if (CUTLASS_CLANG_DEVICE_COMPILE) link_libraries(nvidia::cudart) link_libraries(nvidia::cuda_driver) - + endif() #Report CUDA build flags @@ -561,7 +561,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET) list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real) endif() if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE) - # If we're using clang for device compilation, the ptx is inserted + # If we're using clang for device compilation, the ptx is inserted # via another command line option and the `-virtual` flags will cause an error. 
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual) endif() @@ -922,7 +922,7 @@ function(cutlass_add_executable_tests NAME TARGET) if (NOT __DO_NOT_LOWERCASE_TEST_NAME) string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME) endif() - + # The following rigmarole is needed to deal with spaces and possible quotes in # command line arguments. The options are passed "by reference" as the actual # variable names holding the real options. We then expand these in a way that @@ -1007,46 +1007,51 @@ function(cutlass_generate_profiler_tests NAME) endif() file(STRINGS ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} TEST_LIST) - foreach(TEST IN LISTS TEST_LIST) - + set(TEMP_TEST ${TEST}) if ("${TEST}" MATCHES " *cutlass_profiler.*") - # Generate a flattened name for the test from the test command line. - string(REPLACE "," ";" TEST_NAME_LIST ${TEST}) - list(GET TEST_NAME_LIST 0 TEST) - string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST}") - list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass") - list(JOIN TEST_NAME "_" TEST_NAME) - string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}") - string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}") - string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}") - string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "=" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}") - string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}") + # Generate a flattened name for the test from the test command line. 
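# Note: the string operations below flatten each regression-list entry into a unique,
# filesystem-friendly test name by dropping the profiler invocation and bookkeeping flags
# (verification settings, cluster fallback shapes, runtime datatypes, warmup/profiling
# iteration counts, etc.) and joining the remaining tokens with underscores.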
+ string(REPLACE "," ";" TEST_NAME_LIST ${TEMP_TEST}) + string(REGEX REPLACE "\\*" "_" TEST_NAME "${TEMP_TEST}") + string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST "${TEST}") + string(REGEX REPLACE "," ";" TEST "${TEST}") + string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST_NAME}") + list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass") + list(JOIN TEST_NAME "_" TEST_NAME) + string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}") + string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "sleep_duration=[0-9]+" "" TEST_NAME "${TEST_NAME}") + string(REGEX REPLACE "profiling_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "=" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}") + string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}") - if (__DO_NOT_LOWERCASE_TEST_NAME) - string(TEST_NAME_LOWER "${TEST_NAME}") - else() - string(TOLOWER "${TEST_NAME}" TEST_NAME_LOWER) - endif() + if (NOT __DO_NOT_LOWERCASE_TEST_NAME) + string(TOLOWER "${TEST_NAME}" TEST_NAME) + endif() - # Munge the test command - string(REPLACE "cutlass_profiler" "" TEST "${TEST}") - set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME_LOWER}") - set(TEST_COMMAND_${TEST_NAME_LOWER} "${TEST}") - list(APPEND TEST_COMMAND_VARS ${TEST_NAME_LOWER}) + # Munge the test command + string(REPLACE "cutlass_profiler" "" TEST "${TEST}") + set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME}") + set(TEST_COMMAND_${TEST_NAME} "${TEST}") + list(APPEND TEST_COMMAND_VARS ${TEST_NAME}) endif() endforeach() @@ -1084,6 +1089,14 @@ if (CUTLASS_ENABLE_TESTS) if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) add_dependencies(test_all test_unit) endif() + if (CUTLASS_ENABLE_PROFILER_UNIT_TESTS AND CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS) + # Generate profiler based unit test + cutlass_generate_profiler_tests( + tup + DEPENDEES test_unit + ) + endif() + endif() if (CUTLASS_INSTALL_TESTS) diff --git a/ACTIVE_DEVELOPERS.md b/CONTRIBUTORS.md similarity index 96% rename from ACTIVE_DEVELOPERS.md rename to CONTRIBUTORS.md index 6ae47b43..1ef06a36 100644 --- a/ACTIVE_DEVELOPERS.md +++ b/CONTRIBUTORS.md @@ -27,7 +27,7 @@ Siyu Liu
Richard Cai
Vikas Gupta
Ethan Yan
-Vijay Thakkar (CUTLASS 3.x founding member)
+Vijay Thakkar (CUTLASS 3.x and CuTe founding member)
Cris Cecka (CuTe and CUTLASS 3.x founding member)
Lawrence Ryan
Qun Song
diff --git a/README.md b/README.md index f9f23c08..ada18b39 100644 --- a/README.md +++ b/README.md @@ -43,23 +43,23 @@ architecture. CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture. For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8. -* Support for new CuTe building blocks specifically for Blackwell architecture: +* Support for new CuTe building blocks specifically for Blackwell SM100 architecture: - [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms. - Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms. - - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp#L290) across CuTe as a first class data locale. + - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale. - Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe. - [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms. - Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms. -* Support for new CUTLASS building blocks specifically for Blackwell architecture: +* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture: - Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h) - [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp). - [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp). - Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types. - Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp). - Extensions to testbeds and reference check code for unit tests and CUTLASS profiler. -* Full support for Blackwell kernels in CUTLASS 3.x API: +* Full support for Blackwell SM100 kernels in CUTLASS 3.x API: - [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that - + Implement a new warp-specialization recipe tuned specifically for Blackwell. + + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture. + Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators. + Support stream-K load balancing for all kernel types everywhere via composable scheduler support. - Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for @@ -73,7 +73,10 @@ For a background on Blackwell's new features, please consult the PTX documentati * CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification. 
- Support for preferred and fallback cluster shapes via profiler command line arguments parsing to set dynamic cluster shapes. - Support for dynamic datatypes by parsing profiler via profiler command line arguments parsing to set dynamic datatype setting in TCGen05 MMA instruction descriptors. -* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell + - Support for mixed input GEMM kernels on Hopper in the profiler. +* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels. +* A new 3.x version of grouped GEMM added to the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details). +* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture: - [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API. - GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell. - Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores: @@ -85,6 +88,10 @@ For a background on Blackwell's new features, please consult the PTX documentati - Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu). - Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu). - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32,64, and 128. + - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations. +* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture: + - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes. + - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu). * Documentation updates: - [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md) diff --git a/customConfigs.cmake b/customConfigs.cmake index c86e15be..e39212db 100644 --- a/customConfigs.cmake +++ b/customConfigs.cmake @@ -47,7 +47,7 @@ function(cutlass_generate_kernel_filter_and_testlists_files) COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR} ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py --generator-target=${__TEST_SET_NAME} - --cuda-version=${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION} + --cuda-version=${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} --architectures=${CUTLASS_NVCC_ARCHS} --kernels=\* --disable-cutlass-package-imports diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index 0c407d34..74aa3614 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -45,7 +45,7 @@ CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can improve performance. Examples: - $ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \ + $ ./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \ --m=2816 --n=3072 --k=16384 \ --save_aux=false --save_amax=false \ --device_scale=false --raster=h --swizzle=2 @@ -119,22 +119,56 @@ using ElementBias = float; using ElementAccumulator = float; // Element type for internal accumulation using ElementBlockScale = float; // Element type for blockscaling during accumulation using ElementCompute = float; // Element type for epilogue computation -using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature -using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag -using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size -using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -constexpr int ScaleMsPerTile = 2; -constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile; +using TileShape_ = Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... 
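// A minimal worked example of the scale-granularity arithmetic, assuming TileShape_ = Shape<_128,_128,_128>:
//   GroupScaleConfig<  1,   1>  ->  ScaleMsPerTile = 128, ScaleNsPerTile = 128  (1Dx1D scaling)
//   GroupScaleConfig<  1, 128>  ->  ScaleMsPerTile = 128, ScaleNsPerTile =   1  (1Dx2D scaling)
//   GroupScaleConfig<128,   1>  ->  ScaleMsPerTile =   1, ScaleNsPerTile = 128  (2Dx1D scaling)
//   GroupScaleConfig<128, 128>  ->  ScaleMsPerTile =   1, ScaleNsPerTile =   1  (2Dx2D scaling)
// since ScaleMsPerTile = size<0>(TileShape) / ScaleGranularityM and ScaleNsPerTile = size<1>(TileShape) / ScaleGranularityN.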
-using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; -using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor +// Given TileShape = Shape<_128,_128,_128>: +// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) +// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling +// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling +// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling +template +struct GroupScaleConfig { + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size + using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster -using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; -using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + static constexpr int ScaleGranularityM = ScaleGranularityM_; + static constexpr int ScaleGranularityN = ScaleGranularityN_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, + "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, + "FP8 scaling granularity must evenly divide tile shape along N."); + + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; +}; -using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< +using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; +using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; +using GroupScale2D1DConfig = GroupScaleConfig(TileShape_{}), 1>; +using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; + +template +struct GroupScaleGemm { + using ArchTag = typename ScheduleConfig::ArchTag; + using OperatorClass = typename ScheduleConfig::OperatorClass; + using TileShape = typename ScheduleConfig::TileShape; + using ClusterShape = typename ScheduleConfig::ClusterShape; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using FusionOperation = typename ScheduleConfig::FusionOperation; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, @@ -145,7 +179,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui FusionOperation >::CollectiveOp; -using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< + using 
CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, @@ -157,24 +191,38 @@ using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collectiv KernelSchedule >::CollectiveOp; -using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, // Indicates ProblemShape - CollectiveMainloopWithBlockWiseScaling, - CollectiveEpilogue ->; + using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue + >; -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloopWithGroupWiseScaling, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using GemmDefault = cutlass::gemm::device::GemmUniversalAdapter; + using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GroupScale1D1DGemm = GroupScaleGemm; +using GroupScale1D2DGemm = GroupScaleGemm; +using GroupScale2D1DGemm = GroupScaleGemm; +using GroupScale2D2DGemm = GroupScaleGemm; // Extract information from Gemm kernel. -using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp; using ElementScalar = typename EpilogueOutputOp::ElementScalar; using ElementAmax = typename EpilogueOutputOp::ElementAmax; using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; -using StrideA = typename Gemm::GemmKernel::StrideA; -using StrideB = typename Gemm::GemmKernel::StrideB; -using StrideC = typename Gemm::GemmKernel::StrideC; -using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA; +using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB; +using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC; +using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD; using StrideAux = StrideD; constexpr bool IsDFp8 = @@ -185,9 +233,6 @@ constexpr bool IsAuxFp8 = cute::is_same_v or cute::is_same_v; -static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, - "FP8 scaling granularity must evenly divide tile shape along M."); - static_assert(cute::is_same_v, "ElementAccumulator and ElementBlockScale should be same datatype"); @@ -347,13 +392,18 @@ struct Result } /// Initialize operands to be used in the GEMM and reference GEMM +template void initialize(const Options &options) { + using TileShape = typename GroupScaleConfig::TileShape; + const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; + const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; + // Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. - auto blockscale_n = cute::get<1>(blockscale_shape); + auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access. 
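// Note: blockscale_shape holds the per-mode CTA tile counts of the problem (zipped_divide splits the
// problem layout into (TileShape, rest) and get<1> keeps the rest), so scaling the M and N tile counts
// by ScaleMsPerTile / ScaleNsPerTile yields the extents of the group-scale tensors for A (padded along M)
// and B (padded along N).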
auto blockscale_k = cute::get<2>(blockscale_shape); stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); @@ -362,18 +412,16 @@ void initialize(const Options &options) { stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); stride_aux = stride_D; - - auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k); - auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l); + auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k); tensor_A.resize(a_coord); - blockscale_tensor_A.resize(groupscale_a_coord); tensor_B.resize(b_coord); - blockscale_tensor_B.resize(blockscale_b_coord); + blockscale_tensor_A.resize(groupscale_a_coord); + blockscale_tensor_B.resize(groupscale_b_coord); tensor_C.resize(c_coord); tensor_D.resize(c_coord); tensor_ref_D.resize(c_coord); @@ -393,7 +441,7 @@ void initialize(const Options &options) { #if 0 // Dump blockscaled tensors std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl; std::cout << blockscale_tensor_A.host_view() << "\n"; - std::cout << "blockscale_tensor_B: " << blockscale_b_coord << std::endl; + std::cout << "blockscale_tensor_B: " << groupscale_b_coord << std::endl; std::cout << blockscale_tensor_B.host_view() << "\n"; #endif @@ -441,21 +489,26 @@ void initialize(const Options &options) { if (IsDFp8 && options.save_amax) { abs_max_D.resize(cutlass::make_Coord(1)); + initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0); abs_max_D.sync_device(); reference_abs_max_D.resize(cutlass::make_Coord(1)); + initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0); } if (IsAuxFp8 && options.save_aux && options.save_amax) { abs_max_aux.resize(cutlass::make_Coord(1)); + initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0); abs_max_aux.sync_device(); reference_abs_max_aux.resize(cutlass::make_Coord(1)); + initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0); } } /// Populates a Gemm::Arguments structure from the given commandline options -typename Gemm::Arguments args_from_options(const Options &options) +template +GemmArguments args_from_options(const Options &options) { - typename Gemm::Arguments arguments{ + GemmArguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {options.m, options.n, options.k, options.l}, {tensor_A.device_data(), @@ -513,14 +566,15 @@ typename Gemm::Arguments args_from_options(const Options &op return arguments; } -bool verify(const Options &options) { +/// Don't know why the compiler does not like verify() being templated... 
+bool verify(const Options &options, const int ScaleMsPerTile, const int ScaleNsPerTile) { // // Compute reference output // // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); - auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{}))); + auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{}))); auto blockscale_m = cute::get<0>(blockscale_shape); auto blockscale_n = cute::get<1>(blockscale_shape); auto blockscale_k = cute::get<2>(blockscale_shape); @@ -565,8 +619,8 @@ bool verify(const Options &options) { ); auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(), cute::make_layout( - cute::make_shape(blockscale_n, blockscale_k, options.l), - cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k) + cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l), + cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile) ) ); @@ -575,7 +629,7 @@ bool verify(const Options &options) { cutlass::reference::host::GettMainloopParams mainloop_params{ + TileShape_> mainloop_params{ A, B, // Operand Tensors blockscale_A, blockscale_B // Groupwise scaling Tensors }; @@ -641,16 +695,22 @@ bool verify(const Options &options) { } /// Execute a given example GEMM computation -template +template int run(Options &options) { - initialize(options); + using TileShape = typename GroupScaleConfig::TileShape; + const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM; + const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN; + const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile; + const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile; + + initialize(options); // Instantiate CUTLASS kernel depending on templates Gemm gemm; // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm - auto arguments = args_from_options(options); + auto arguments = args_from_options(options); // Using the arguments, query for extra workspace required for matrix multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -669,7 +729,7 @@ int run(Options &options) // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; - result.passed = verify(options); + result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile); std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; @@ -683,6 +743,7 @@ int run(Options &options) GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); CUTLASS_CHECK(gemm.run()); } timer.stop(); @@ -702,9 +763,13 @@ int run(Options &options) } std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; std::cout << " GFLOPS: " << result.gflops << std::endl; + fflush(stdout); } return 0; @@ -753,7 +818,27 @@ int main(int argc, char const **args) { // #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) - run(options); + std::cout << "Basic split-K GEMM kernel" << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; + + std::cout << std::endl; + + std::cout << "StreamK GEMM kernel" << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; + run(options); + std::cout << std::endl; #endif return 0; diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h index cb3ff022..e9809f6b 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -220,10 +220,12 @@ void gett_mainloop( int64_t block_m = m / kBlockM; int64_t block_n = n / kBlockN; cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l); - cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l); + cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l); const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape()); - assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape())); + const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape()); + assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape())); + assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape())); // Compute on this k-block for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { @@ -231,7 +233,7 @@ void gett_mainloop( // Load Blockwise scaling factor from blockscale Tensors for B int64_t block_k = k / kBlockK; cute::Tensor scale_a = blockscale_A(_, block_k); - ElementBlockScaleB scale_b = blockscale_B[block_k]; + 
cute::Tensor scale_b = blockscale_B(_, block_k); // Load A ElementAccumulator a_frag[kBlockM]; @@ -268,8 +270,10 @@ // (c) Update permanent (accu) if ((k+1) % kBlockK == 0) { for (int m_b = 0; m_b < kBlockM; ++m_b) { + auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; for (int n_b = 0; n_b < kBlockN; ++n_b) { - ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a[m_b / ScaleGranularityM] * scale_b; + auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; + ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; acc_temp[m_b][n_b] = ElementAccumulator(0); } diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu index b22d8305..c1978c32 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu @@ -32,7 +32,19 @@ /*! \file \brief - NOTE: Write docu + Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + See 55_hopper_int4_bf16_gemm.cu for more details about W4A16 GEMMs with layout shuffling. + + Limitations: + 1) Only row-wise scaling is supported; zero-points and block-wise scaling are currently not supported. + + To run this example: + + $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10 + + The above command sizes all 10 groups at the given m, n, and k values. + Skipping any of the problem dimensions randomizes that dimension across the different groups. + The same applies to the alpha and beta values, which are randomized across the different groups. */ #include diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu index cc0494ec..07ff66b3 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu @@ -32,7 +32,19 @@ /*! \file \brief - NOTE: Write docu + Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + See 55_hopper_int4_fp8_gemm.cu for more details about W4A8 GEMMs with lookup table. + + Limitations: + 1) Only row-wise scaling is supported; zero-points and block-wise scaling are currently not supported. + + To run this example: + + $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10 + + The above command sizes all 10 groups at the given m, n, and k values. + Skipping any of the problem dimensions randomizes that dimension across the different groups. + The same applies to the alpha and beta values, which are randomized across the different groups. */ #include diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu index 883d8cbf..ffeb233e 100644 --- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu @@ -31,7 +31,19 @@ /*!
\file \brief - NOTE: Write docu + Hopper Mixed-input Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + See 55_hopper_mixed_dtype_gemm.cu for more details about Mixed-input GEMMs. + + Limitations: + 1) Only row-wise scaling is supported; zero-points and block-wise scaling are currently not supported. + + To run this example: + + $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10 + + The above command sizes all 10 groups at the given m, n, and k values. + Skipping any of the problem dimensions randomizes that dimension across the different groups. + The same applies to the alpha and beta values, which are randomized across the different groups. */ #include diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md new file mode 100644 index 00000000..272d36e5 --- /dev/null +++ b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md @@ -0,0 +1,14 @@ +This example extends Example 55 to support Grouped GEMMs in CUTLASS. + +## High level overview + +This example shows how to perform Grouped GEMMs on Hopper when A and B have different types. In the Grouped GEMM, multiple GEMMs with potentially different problem shapes can be executed in a batch. The interface is similar to the standard mixed-input GEMM presented in Example 55, with a few noteworthy differences: +- Inside the collective builder, replace the layout types with layout pointer types. +- In the arguments, pass the number of groups, the array of problem sizes, and the arrays of strides for matrices A and B. +- If scales and zero-points are included, also pass the arrays of their strides in the arguments. + +Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs. + +## Upcoming features + +Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed. diff --git a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu index 39123cac..3cee6caf 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu @@ -115,15 +115,11 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O using MmaTileShape_MNK = Shape<_256,_128,_64>; // Shape of the threadblocks in a cluster using ClusterShape_MNK = Shape<_2,_2,_1>; -// Shape of the threadblocks participating in a tcgen05 MMA.
<1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 -using AtomThrShape_MNK = Shape<_2, _1, _1>; -// Shape of the tile computed by each SM -using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); // Build the epilogue using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, diff --git a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu index 0b1758b9..69a36310 100644 --- a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu +++ b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu @@ -131,17 +131,13 @@ using ElementAmax = float; using MmaTileShape_MNK = Shape<_256,_128,_64>; // Shape of the threadblocks in a cluster using ClusterShape_MNK = Shape<_2,_2,_1>; -// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 -using AtomThrShape_MNK = Shape<_2, _1, _1>; -// Shape of the tile computed by each SM -using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); using FusionOp = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< LayoutC, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - PerSmTileShape_MNK, ClusterShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt index d88a8c56..cb401e3a 100644 --- a/examples/70_blackwell_gemm/CMakeLists.txt +++ b/examples/70_blackwell_gemm/CMakeLists.txt @@ -28,7 +28,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 70_blackwell_fp16_gemm 70_blackwell_fp16_gemm.cu diff --git a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu index 6712d7a9..427af254 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu +++ b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu @@ -184,12 +184,8 @@ struct ExampleRunner { std::is_same_v || // Auto schedule will try to select 2sm cluster MMA based on cluster M std::is_same_v && size<0>(ClusterShapeMNK{}) % 2 == 0; - // The MNK layout of CTAs within a cluster MMA - using AtomThrMNK = std::conditional_t, Shape<_1,_1,_1>>; // The MMA tile used by the mainloop collective. 
Blackwell 1sm MMA supports up to MMA tile M = 128, 2sm MMA supports up to MMA tile M = 256 using MmaTileMNK = std::conditional_t, Shape<_128,_128,_64>>; - // The Output tile used by the epilogue collective - using OutputTileMNK = decltype(shape_div(MmaTileMNK{}, AtomThrMNK{})); // 16B alignment lets us use TMA static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; @@ -220,7 +216,7 @@ struct ExampleRunner { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputTileMNK, ClusterShapeMNK, + MmaTileMNK, ClusterShapeMNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, @@ -503,20 +499,20 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and 3 mainloop stages", passed); // 1SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules - ExampleRunner runner_2; + ExampleRunner runner_2; passed = runner_2.run(options, hw_info); - print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed); + print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue schedule", passed); // 1SM cluster MMA mainloop schedules can also be used with 1SM TMA epilogue schedules // 1SM cluster MMA mainloop schedules will not work with 2SM TMA epilogue schedules ExampleRunner runner_3; passed = runner_3.run(options, hw_info); - print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed); + print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue schedule", passed); // 2SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules - ExampleRunner runner_4; + ExampleRunner runner_4; passed = runner_4.run(options, hw_info); - print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed); + print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized2Sm epilogue schedule", passed); // 2SM cluster MMA mainloop schedules can also be used with 2SM TMA epilogue schedules // 2SM cluster MMA mainloop schedules will not work with SM TMA epilogue schedules @@ -556,11 +552,11 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN // Blackwell direct store epilogue schedule supports custom EVTs and named fusion operations as well (not supported for pre-Blackwell kernels) ExampleRunner< cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, - cutlass::epilogue::NoSmemWarpSpecialized, + cutlass::epilogue::NoSmemWarpSpecialized1Sm, cutlass::gemm::collective::StageCountAuto, UseCustomEVT> runner_9; passed = runner_9.run(options, hw_info); - print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue and custom EVT", passed); + print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue and custom EVT", passed); #endif diff --git a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt index 5bac6494..a326f461 100644 --- a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt +++ 
b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # Both filenames are shorter to avoid MAX_PATH issues on Windows. -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 71_blackwell_gemm_with_collective_builder 71_blackwell_gemm_with_collective_builder.cu diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu index ec597966..f7e12fbf 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu @@ -117,11 +117,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O // Kernel Perf config using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster -using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape, + MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutCTag, AlignmentC, diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu index cefa3e92..2719cab9 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu @@ -121,7 +121,6 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O // Kernel Perf config using MmaTileShape = Shape<_128,_128,_256>; // MMA's tile size using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster -using PerSmTileShape_MNK = Shape<_128,_128,_256>; // Threadblock-level tile size constexpr int InputSFVectorSize = 16; constexpr int OutputSFVectorSize = InputSFVectorSize; @@ -137,7 +136,7 @@ using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor< using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape, + MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutCTag, AlignmentC, diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu index b73f2c94..2784d050 100644 --- a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu +++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu @@ -118,11 +118,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O // Kernel Perf config using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster -using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape, + MmaTileShape, ClusterShape, 
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutCTag, AlignmentC, diff --git a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt index fa80c184..eaeb6600 100644 --- a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt +++ b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt @@ -28,7 +28,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 72a_blackwell_nvfp4_bf16_gemm 72a_blackwell_nvfp4_bf16_gemm.cu diff --git a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt index 0d0f7757..a4a18324 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt +++ b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt @@ -28,7 +28,7 @@ -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 73_blackwell_gemm_preferred_cluster blackwell_gemm_preferred_cluster.cu diff --git a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu index fb62e844..19c4efd1 100644 --- a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu +++ b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu @@ -129,27 +129,22 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O // MMA and Cluster Tile Shapes // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0 using MmaTileShape_MNK = Shape<_256,_128,_64>; -// Shape of the threadblocks participating in a tcgen05 MMA. 
<1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 -using AtomThrShape_MNK = Shape<_2, _1, _1>; -// Shape of the tile computed by each SM -using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); // Shape of the cluster set to to indicate dynamic cluster shape using ClusterShape_MNK = Shape; // When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that // lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2 -// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1> -using KernelSchedule = cute::conditional_t; +// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, - cutlass::epilogue::collective::EpilogueScheduleAuto + EpilogueSchedule >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< diff --git a/examples/74_blackwell_gemm_streamk/CMakeLists.txt b/examples/74_blackwell_gemm_streamk/CMakeLists.txt index 618561f5..5a378241 100644 --- a/examples/74_blackwell_gemm_streamk/CMakeLists.txt +++ b/examples/74_blackwell_gemm_streamk/CMakeLists.txt @@ -29,7 +29,7 @@ -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 74_blackwell_gemm_streamk blackwell_gemm_streamk.cu diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu index bb99fa4a..8f6def99 100644 --- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -133,22 +133,17 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O // MMA and Cluster Tile Shapes // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0 using MmaTileShape_MNK = Shape<_256,_128,_64>; -// Shape of the threadblocks participating in a tcgen05 MMA. 
<1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2 -using AtomThrShape_MNK = Shape<_2, _1, _1>; -// Shape of the tile computed by each SM -using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{})); // Shape of the cluster set to to indicate dynamic cluster shape using ClusterShape_MNK = Shape; -// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that +// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that // lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2 -// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1> -using KernelSchedule = cute::conditional_t; +// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape_MNK, ClusterShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu index 520d8cee..1d8db6e2 100644 --- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu @@ -121,32 +121,25 @@ using StageCountType = cutlass::gemm::collective::StageCountAuto; // S // Runtime Cluster Shape using ClusterShape = Shape; -// For Static Cluster Shape: -// using ClusterShape = Shape<_2,_1,_1>; // for example -// using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config -// using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder -// using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder // Different configs for 1SM and 2SM MMA kernel struct MMA1SMConfig { using MmaTileShape = Shape<_128,_256,Int<128 / sizeof(ElementA)>>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch - using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{})); }; struct MMA2SMConfig { using MmaTileShape = Shape<_256,_256,Int<128 / sizeof(ElementA)>>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch - using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{})); }; template struct GivenGemmSchedule { using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - typename ScheduleConfig::OutputTileShape, ClusterShape, + typename ScheduleConfig::MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC *, AlignmentC, diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu index fa65e508..ee697135 100644 --- 
a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu +++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu @@ -143,31 +143,23 @@ using StageCountType = cutlass::gemm::collective::StageCountAuto; // S // Runtime Cluster Shape using ClusterShape = Shape; -/* // For Static Cluster Shape: -use ClusterShape = Shape<_2,_1,_1> for example -using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config -using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder -using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder -*/ // Different configs for 1SM and 2SM MMA kernel struct MMA1SMConfig { using MmaTileShape = Shape<_128,_256,_256>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch - using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{})); }; struct MMA2SMConfig { using MmaTileShape = Shape<_256,_256,_256>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100; // Kernel to launch using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch - using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{})); }; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, EpilogueOperatorClass, - typename MMA1SMConfig::OutputTileShape, ClusterShape, + typename MMA1SMConfig::MmaTileShape, ClusterShape, Shape<_128,_64>, ElementAccumulator, ElementAccumulator, ElementC, LayoutC *, AlignmentC, @@ -195,7 +187,7 @@ using Gemm = Gemm1SM; using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, EpilogueOperatorClass, - typename MMA2SMConfig::OutputTileShape, ClusterShape, + typename MMA2SMConfig::MmaTileShape, ClusterShape, Shape<_128,_64>, ElementAccumulator, ElementAccumulator, ElementC, LayoutC *, AlignmentC, diff --git a/examples/75_blackwell_grouped_gemm/CMakeLists.txt b/examples/75_blackwell_grouped_gemm/CMakeLists.txt index 2da2d4c4..0ce48662 100644 --- a/examples/75_blackwell_grouped_gemm/CMakeLists.txt +++ b/examples/75_blackwell_grouped_gemm/CMakeLists.txt @@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0) set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 75_blackwell_grouped_gemm 75_blackwell_grouped_gemm.cu diff --git a/examples/76_blackwell_conv/CMakeLists.txt b/examples/76_blackwell_conv/CMakeLists.txt index 8d31d743..e4042aa6 100644 --- a/examples/76_blackwell_conv/CMakeLists.txt +++ b/examples/76_blackwell_conv/CMakeLists.txt @@ -28,7 +28,7 @@ -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 76_blackwell_conv_fprop 76_blackwell_conv_fprop.cu diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index c840d8ba..90b47387 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -49,7 +49,7 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap) set(TEST_GEN_CACHEONLY --b=2 --h=4 
--h_k=2 --k=512 --d=128 --verify --cache-only) if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang"))) - if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") + if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 77_blackwell_fmha_fp8 77_blackwell_fmha.cu diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu index 48e1da6c..f50e85b4 100644 --- a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu +++ b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu @@ -106,25 +106,23 @@ using ArchTag = cutlass::arch::Sm100; // T using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag // Kernel Perf config -using ClusterTileShape = Shape<_256,_128,_16>; // Cluster-level tile shape -using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster -using CtaTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // Threadblock-level tile shape -using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape +using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape // Build the epilogue using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - CtaTileShape, ClusterShape, + MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, - cutlass::epilogue::NoSmemWarpSpecialized + cutlass::epilogue::NoSmemWarpSpecialized2Sm >::CollectiveOp; // Build the mainloop // Note: Emulated BF16x9 kernels need to manually specify a mainloop schedule and cannot use KernelScheduleAuto -using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedFastFP32SmemSm100; +using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmFastFP32SmemSm100; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt index 1b36a4fd..6fcbd062 100644 --- a/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt +++ b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt @@ -28,7 +28,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100") +if (CUTLASS_NVCC_ARCHS MATCHES 100a) cutlass_example_add_executable( 78_blackwell_emulated_bf16x9_gemm 78_blackwell_emulated_bf16x9_gemm.cu diff --git a/examples/README.md b/examples/README.md index ec39bf22..68bf7077 100644 --- a/examples/README.md +++ b/examples/README.md @@ -254,7 +254,7 @@ Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules and epilogue visitor tree (EVT) construction -* [72a_blackwell_narrow_precision_gemm](72a_blackwell_narrow_precision_gemm) +* [72_blackwell_narrow_precision_gemm](72_blackwell_narrow_precision_gemm/) Block-scaled dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs. 
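The example updates above (73_blackwell_gemm_preferred_cluster, 74_blackwell_gemm_streamk, 75_blackwell_grouped_gemm, 78_blackwell_emulated_bf16x9_gemm) all apply the same recipe: the hand-derived AtomThrShape_MNK / PerSmTileShape_MNK / OutputTileShape aliases are removed, the MMA tile shape is handed to the collective builders directly, and cta_group = 2 is requested by naming an explicit 2SM kernel and epilogue schedule instead of relying on KernelScheduleAuto. The sketch below gathers that recipe in one place. The element types, layouts, alignments, the Shape<int,int,_1> runtime cluster shape, and the mainloop-builder argument order are illustrative assumptions modeled on the existing CUTLASS 3.x examples, not text taken from this diff.

// Sketch of the explicit-2SM builder pattern used by the updated Blackwell examples.
// Element/layout/alignment choices and the runtime cluster shape are assumptions.
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cute/tensor.hpp"

using namespace cute;

using ElementA           = cutlass::half_t;               // illustrative
using ElementB           = cutlass::half_t;               // illustrative
using ElementC           = cutlass::half_t;               // illustrative
using ElementAccumulator = float;
using LayoutA            = cutlass::layout::RowMajor;     // illustrative
using LayoutB            = cutlass::layout::ColumnMajor;  // illustrative
using LayoutC            = cutlass::layout::ColumnMajor;  // illustrative
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

using ArchTag          = cutlass::arch::Sm100;
using OperatorClass    = cutlass::arch::OpClassTensorOp;
using MmaTileShape_MNK = Shape<_256,_128,_64>;             // whole MMA tile, may span 2 SMs
using ClusterShape_MNK = Shape<int,int,_1>;                // dynamic (runtime) cluster shape, assumed

// Explicit 2SM schedules replace KernelScheduleAuto / EpilogueScheduleAuto.
using KernelSchedule   = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    MmaTileShape_MNK, ClusterShape_MNK,                    // MMA tile passed directly, no per-SM tile
    cutlass::epilogue::collective::EpilogueTileAuto,
    ElementAccumulator, ElementAccumulator,
    ElementC, LayoutC, AlignmentC,
    ElementC, LayoutC, AlignmentC,
    EpilogueSchedule
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    MmaTileShape_MNK, ClusterShape_MNK,
    cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
    KernelSchedule
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

As the updated comments in examples 73 and 74 state, KernelScheduleAuto combined with a dynamic cluster always lowers to cta_group = 1, so naming the 2SM schedules explicitly is what actually enables the 2-SM tcgen05 MMA path.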
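The 78_blackwell_emulated_bf16x9_gemm example updated above (and listed in the README hunk that follows) emulates FP32 GEMM with BF16 math: each FP32 operand is split into three BF16-representable terms, so one FP32 multiply is reconstructed from 3 x 3 = 9 BF16 products accumulated in FP32, which is where the BF16x9 / FastFP32 name comes from. The scalar sketch below only illustrates that decomposition; it is not the CUTLASS kernel, and the to_bf16 helper (which truncates to BF16 precision) is a hypothetical stand-in for the hardware conversion.

// Conceptual, host-only illustration of the BF16x9 idea: split each FP32 value
// into three BF16-representable terms and recover the FP32 product from the
// nine cross terms. Not the CUTLASS implementation.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper: keep only the bits of an FP32 value that survive a
// conversion to BF16 (sign, exponent, top 7 mantissa bits), i.e. truncate.
static float to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float a = 1.2345678f;
  float b = 0.87654321f;

  // Three-way split: hi holds the leading ~8 mantissa bits, mid and lo the residuals.
  float a_hi = to_bf16(a), a_mid = to_bf16(a - a_hi), a_lo = to_bf16(a - a_hi - a_mid);
  float b_hi = to_bf16(b), b_mid = to_bf16(b - b_hi), b_lo = to_bf16(b - b_hi - b_mid);

  // Nine BF16-by-BF16 products, accumulated in FP32, approximate one FP32 product.
  float emulated = a_hi * b_hi + a_hi * b_mid + a_hi * b_lo
                 + a_mid * b_hi + a_mid * b_mid + a_mid * b_lo
                 + a_lo * b_hi + a_lo * b_mid + a_lo * b_lo;

  std::printf("fp32 product: %.9g  bf16x9 estimate: %.9g  abs error: %g\n",
              a * b, emulated, std::fabs(a * b - emulated));
  return 0;
}

The kernel itself performs these products with BF16 tensor-core MMAs and FP32 accumulation rather than scalar host math, which is why the example pairs the FastFP32 mainloop schedule with ordinary FP32 inputs and outputs.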
@@ -278,6 +278,10 @@ Blackwell SM100 FMHA kernel +* [78_blackwell_emulated_bf16x9_gemm](78_blackwell_emulated_bf16x9_gemm) + + Blackwell SM100 FastFP32 (using BF16 to emulate SGEMM) kernel + # CuTe - Programming Examples Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/). diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index e1950b92..b97fc4c8 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -86,5 +86,3 @@ #define CUTE_ARCH_FLOAT2_MATH_ENABLED #endif - - diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index a157008c..f5f50647 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -208,6 +208,7 @@ to_CUtensorMapDataType() { if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp index d954544f..1f74223b 100644 --- a/include/cute/arch/mma_sm100_umma.hpp +++ b/include/cute/arch/mma_sm100_umma.hpp @@ -956,7 +956,7 @@ template struct SM100_MMA_MXF4_SS { - static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster OMMA."); + static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA."); static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256."); static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32."); diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index cd344fd5..6a767ae3 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -45,7 +45,6 @@ namespace cute { - template <> struct Copy_Traits { diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp index ea3eaf72..a431fc4a 100644 --- a/include/cute/container/array.hpp +++ b/include/cute/container/array.hpp @@ -372,7 +372,7 @@ void swap(array& a, array& b) /// @return A cute::array of the elements of @c t in reverse order. 
template CUTE_HOST_DEVICE constexpr -cute::array reverse(cute::array const& t) +cute::array reverse(cute::array const& t) { if constexpr (N == 0u) { return t; @@ -441,17 +441,6 @@ struct tuple_element> using type = T; }; -template -struct tuple_size const> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element const> -{ - using type = T; -}; - } // end namespace CUTE_STL_NAMESPACE #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD @@ -477,16 +466,5 @@ struct tuple_element> using type = T; }; -template -struct tuple_size const> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element const> -{ - using type = T; -}; - } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index d6b1fafb..38da7ace 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -611,17 +611,6 @@ struct tuple_element> using type = T; }; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> -{ - using type = T; -}; - } // end namespace CUTE_STL_NAMESPACE #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD @@ -647,16 +636,5 @@ struct tuple_element> using type = T; }; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> -{ - using type = T; -}; - } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/packed_tuple.hpp b/include/cute/container/packed_tuple.hpp deleted file mode 100644 index a7a1c3b2..00000000 --- a/include/cute/container/packed_tuple.hpp +++ /dev/null @@ -1,254 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include -#include -#include -#include - -namespace cute { - -namespace detail { - -// Empty Structure Optimization -template -struct ESO; - -template -static constexpr bool is_first_empty_v = cute::is_empty::value; -template -static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); - -template -using ESO_t = ESO, is_rest_empty_v, T...>; - -// Empty First and Empty Rest... -template -struct ESO { - CUTE_HOST_DEVICE constexpr - ESO() {} - - CUTE_HOST_DEVICE constexpr - ESO(First const&, Rest const&...) {} -}; - -// NonEmpty First and Empty Rest... -template -struct ESO { - CUTE_HOST_DEVICE constexpr - ESO() : first_{} {} - - CUTE_HOST_DEVICE constexpr - ESO(First const& first, Rest const&...) : first_{first} {} - - First first_; -}; - -// Empty First and NonEmpty Rest... -template -struct ESO { - CUTE_HOST_DEVICE constexpr - ESO() : rest_{} {} - - CUTE_HOST_DEVICE constexpr - ESO(First const&, Rest const&... rest) : rest_{rest...} {} - - ESO_t rest_; -}; - -// NonEmpty T and NonEmpty Rest... -template -struct ESO { - CUTE_HOST_DEVICE constexpr - ESO() : first_{}, rest_{} {} - - CUTE_HOST_DEVICE constexpr - ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} - - First first_; - ESO_t rest_; -}; - -// Get Nth value from ESO -template -CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO const& s) { - if constexpr (N == 0) { - if constexpr (F) { return T{}; } - else { return static_cast(s.first_); } - } else { - if constexpr (R) { return cute::tuple_element_t>{}; } - else { return getv(s.rest_); } - } -} - -template -CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO& s) { - if constexpr (N == 0) { - if constexpr (F) { return T{}; } - else { return static_cast(s.first_); } - } else { - if constexpr (R) { return cute::tuple_element_t>{}; } - else { return getv(s.rest_); } - } -} - -template -CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO&& s) { - if constexpr (N == 0) { - if constexpr (F) { return T{}; } - else { return static_cast(s.first_); } - } else { - if constexpr (R) { return cute::tuple_element_t>{}; } - else { return getv(static_cast&&>(s.rest_)); } - } -} - -// findt: Implementation detail of cute::find. -// If X is the first template argument of the tuple, findt returns C. - -template -CUTE_HOST_DEVICE constexpr -auto -findt(ESO const& t) noexcept -{ - if constexpr (cute::is_same_v) { - return C{}; - } - else { - static_assert(sizeof...(Rest) != 0, - "The type does not appear in the argument list of the tuple."); - if constexpr (IsRestEmpty) { - // The rest is empty, so creating an instance of it is cheap. - return cute::detail::findt(ESO_t{}); - } - else { - return cute::detail::findt(t.rest_); - } - } -} - -} // end namespace detail - -// packed_tuple is a tuple type that is a standard-layout type -// whenever all of its template arguments are standard layout types: -// (cute::is_standard_layout_v && ...) implies (cute::is_standard_layout_v>) - -template -struct packed_tuple : detail::ESO_t -{ - CUTE_HOST_DEVICE constexpr - packed_tuple() {} - - CUTE_HOST_DEVICE constexpr - packed_tuple(T const&... ts) - : detail::ESO_t(ts...) 
- {} -}; - -template <> -struct packed_tuple<> {}; - -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -get(packed_tuple const& t) { - static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); -} - -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -get(packed_tuple& t) { - static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(t); -} - -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -get(packed_tuple&& t) { - static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(static_cast&&>(t)); -} - -template -CUTE_HOST_DEVICE constexpr -packed_tuple -make_packed_tuple(T const&... t) -{ - return {t...}; -} - -// Returns the position of type X (as a static integer) in the tuple -// type's argument list. X must be unique in the argument list. -template -CUTE_HOST_DEVICE constexpr -auto -find(packed_tuple const& t) noexcept -{ - return detail::findt(t); -} - -} // end namespace cute - -namespace CUTE_STL_NAMESPACE -{ - -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - -} // end namespace CUTE_STL_NAMESPACE - -#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD -namespace std { - -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - -} // end namespace std -#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index dab8621e..e62cfe16 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -37,169 +37,183 @@ #include #include -#if defined(CUTLASS_USE_PACKED_TUPLE) -# include -#endif - //#include // Advanced optimizations -// cute::tuple is like std::tuple, with two differences. +// cute::tuple is like std::tuple, with differences: // // 1. It works on both host and device. // 2. Its template arguments must be semiregular types. +// 3. It is always a standard-layout type if all of its template arguments are standard-layout types. +// 4. It is always an empty type if all of its template arguments are empty types. // // Semiregular types are default constructible and copyable. // They include "value types" like int or float, // but do _not_ include references like int& or float&. // (See std::tie for an example of a tuple of references.) // -// If the template arguments of cute::tuple are all empty types (in -// the sense of std::is_empty_v), then the cute::tuple is also an -// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined, -// cute::tuple is always a standard-layout type if all of its template -// arguments are standard-layout types. - -namespace cute -{ - -#if defined(CUTLASS_USE_PACKED_TUPLE) - -template -using tuple = packed_tuple; - -#else - -namespace detail -{ - -// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// Standard-layout types preserve ABI across host-device boundaries. +// They are safe to use as device kernel parameters. +// +// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of // the conversion SFINAE, special overloading, and avoiding cvref template types. // // Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. -// EBO stands for "empty base optimization." 
+namespace cute +{ + +namespace detail +{ + +// ESO stands for "empty structure optimization." // We use this technique to ensure that cute::tuple -// doesn't need to waste space storing any template arguments -// of cute::tuple that have no data (like integral_constant). -// Otherwise, cute::tuple would need to spend at least 1 byte -// for each of its template arguments. -// -// This is one way in which cute::tuple differs from std::tuple. +// doesn't waste space storing template arguments that have no data (like integral_constant). // Empty types in the template argument list are not even constructed, -// and do not have unique element addresses. In fact, they are not -// even members of the tuple or stored in any way. Calling `get` +// and do not have unique element addresses. Calling `get` // constructs and returns an instance of an empty type on demand. -// -// EBO always "holds" a single value of type T. -// N is like an array index that TupleBase uses -// to access the desired tuple element. -template ::value> -struct EBO; -template -CUTE_HOST_DEVICE constexpr C findt(EBO const&) -{ return {}; } +template +struct ESO; -// Specialization for types T that have no data; -// the "static tuple leaf." Valid T here include -// integral_constant, Int, -// and any other semiregular type -// for which std::is_empty_v is true. -template -struct EBO -{ +template +static constexpr bool is_first_empty_v = cute::is_empty::value; +template +static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); + +template +using ESO_t = ESO, is_rest_empty_v, T...>; + +// Empty First and Empty Rest... +template +struct ESO { CUTE_HOST_DEVICE constexpr - EBO() {} + ESO() {} CUTE_HOST_DEVICE constexpr - EBO(T const&) {} + ESO(First const&, Rest const&...) {} }; -template -CUTE_HOST_DEVICE constexpr T getv(EBO const&) -{ return {}; } - -// This is a work around approach to solve a shared memory misalign issue (https://github.com/NVIDIA/cutlass/issues/1250). -// Will remove this work around implementation once the corresponding fix in compiler is released. -struct dummy_EBO_base {}; - -// Specialization for types T that are not empty; -// the "dynamic tuple leaf." Valid T here include int, -// any other integral or floating-point type, -// or any semiregular type for which std::is_empty_v is false. -template -struct EBO : private dummy_EBO_base -{ +// NonEmpty First and Empty Rest... +template +struct ESO { CUTE_HOST_DEVICE constexpr - EBO() : t_{} {} + ESO() : first_{} {} CUTE_HOST_DEVICE constexpr - EBO(T const& t) : t_{t} {} + ESO(First const& first, Rest const&...) : first_{first} {} - T t_; + First first_; }; -template -CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) -{ return x.t_; } - -template -CUTE_HOST_DEVICE constexpr T& getv(EBO& x) -{ return x.t_; } - -template -CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) -{ return cute::move(x.t_); } - -template -struct TupleBase; - -// Base class of cute::tuple binds each element to an index -// by inheriting from EBO for each (i, t) in (I..., T...). -// The storage (for nonempty t) lives in the base classes. -template -struct TupleBase, T...> - : EBO... -{ +// Empty First and NonEmpty Rest... +template +struct ESO { CUTE_HOST_DEVICE constexpr - TupleBase() {} + ESO() : rest_{} {} CUTE_HOST_DEVICE constexpr - TupleBase(T const&... t) : EBO(t)... {} + ESO(First const&, Rest const&... rest) : rest_{rest...} {} + + ESO_t rest_; }; +// NonEmpty T and NonEmpty Rest... 
+template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{}, rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} + + First first_; + ESO_t rest_; +}; + +// Get Nth value from ESO +template +CUTE_HOST_DEVICE constexpr +cute::enable_if_t>>::value, + cute::tuple_element_t>> +getv(ESO const&) +{ + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +cute::enable_if_t>>::value, + cute::tuple_element_t> const&> +getv(ESO const& s) +{ + if constexpr (N == 0) { + return static_cast(s.first_); + } else { + return getv(s.rest_); + } +} + +template +CUTE_HOST_DEVICE constexpr +cute::enable_if_t>>::value, + cute::tuple_element_t> &> +getv(ESO& s) +{ + if constexpr (N == 0) { + return static_cast(s.first_); + } else { + return getv(s.rest_); + } +} + +template +CUTE_HOST_DEVICE constexpr +cute::enable_if_t>>::value, + cute::tuple_element_t> &&> +getv(ESO&& s) +{ + if constexpr (N == 0) { + return static_cast(s.first_); + } else { + return getv(static_cast&&>(s.rest_)); + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +findt(ESO const& t) noexcept +{ + if constexpr (cute::is_same_v) { + return C{}; + } else + if constexpr (sizeof...(Rest) == 0) { + return C{}; + } else + if constexpr (IsRestEmpty) { + return cute::detail::findt(ESO_t{}); + } else { + return cute::detail::findt(t.rest_); + } +} + } // end namespace detail -// Attempting to use the following commented-out alias -// in the declaration of `struct tuple` causes MSVC 2022 build errors. -// -//template -//using TupleBase = detail::TupleBase, T...>; - -// This is the actual cute::tuple class. -// The storage (if any) lives in TupleBase's EBO base classes. -// -// Inheriting from the above alias TupleBase -// causes MSVC 2022 build errors when assigning one tuple to another: -// In summary: this is verbose as a work-around for MSVC build errors. template -struct tuple : detail::TupleBase, T...> +struct tuple : detail::ESO_t { CUTE_HOST_DEVICE constexpr tuple() {} CUTE_HOST_DEVICE constexpr - tuple(T const&... t) : detail::TupleBase, T...>(t...) {} + tuple(T const&... t) : detail::ESO_t(t...) {} }; template <> -struct tuple<> -{}; - -// -// get for cute::tuple (just like std::get for std::tuple) -// +struct tuple<> {}; +// Returns the element in the ith position of the tuple template CUTE_HOST_DEVICE constexpr decltype(auto) @@ -224,25 +238,19 @@ decltype(auto) get(tuple&& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); - return detail::getv(static_cast&&>(t)); + return detail::getv(static_cast&&>(t)); } -// -// find a type X within a cute::tuple -// Requires X to be unique in tuple -// Returns a static integer -// - +// Returns the position of type X (as a static integer) in the tuple +// type's argument list. X must be unique in the argument list. template CUTE_HOST_DEVICE constexpr auto find(tuple const& t) noexcept { - return detail::findt(t); + return detail::findt(t); } -#endif // CUTLASS_USE_PACKED_TUPLE - // // Custom is_tuple trait simply checks the existence of tuple_size // and assumes std::get(.), std::tuple_element @@ -258,7 +266,7 @@ auto has_tuple_size(...) -> false_type; template struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; -template +template constexpr bool is_tuple_v = cute::is_tuple::value; // @@ -679,8 +687,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) } // end namespace cute -#if ! 
defined(CUTLASS_USE_PACKED_TUPLE) - namespace CUTE_STL_NAMESPACE { @@ -694,22 +700,8 @@ struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - } // end namespace CUTE_STL_NAMESPACE -// -// std compatibility -// - #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { @@ -732,17 +724,5 @@ struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD - -#endif // CUTLASS_USE_PACKED_TUPLE diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index 44001b6d..b8ac5f0d 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -73,17 +73,6 @@ struct tuple_element> using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; }; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; - } // end namespace std #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD @@ -109,16 +98,5 @@ struct tuple_element> using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; }; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> -{ - using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; -}; - } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index a5bef3ec..557e1103 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -330,7 +330,7 @@ ceil_div(IntTupleA const& a, IntTupleB const& b) constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); } else { // tuple int - auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), [] (auto const& init, auto const& ai) { return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); }); @@ -390,7 +390,7 @@ shape_div(IntTupleA const& a, IntTupleB const& b) static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); } else { // tuple int - auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), [] (auto const& init, auto const& ai) { return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); }); diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index c1a275c9..adf460bb 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1044,7 +1044,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, auto result_shape_0 = take<0,R-1>(lhs_shape); // Mod out the rhs_shape from the lhs_shape - auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), + auto [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), [] 
(auto const& init, auto const& si) { return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); @@ -1058,7 +1058,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, auto result_stride_0 = take<0,R-1>(lhs_stride); // Divide out the rhs_stride from the lhs_shape - auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), + auto [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), [] (auto const& init, auto const& di) { return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); }); @@ -1067,7 +1067,7 @@ composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); // Mod out the rhs_shape from the lhs_shape - auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + auto [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), [] (auto const& init, auto const& si) { return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 32163072..3c2c23cc 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -508,16 +508,6 @@ struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - } // end namespace CUTE_STL_NAMESPACE #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD @@ -542,15 +532,5 @@ struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; -template -struct tuple_size> - : CUTE_STL_NAMESPACE::integral_constant -{}; - -template -struct tuple_element> - : CUTE_STL_NAMESPACE::tuple_element> -{}; - } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index c2e7456e..485c07d5 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -84,7 +84,6 @@ using CUTE_STL_NAMESPACE::uint16_t; using CUTE_STL_NAMESPACE::uint32_t; using CUTE_STL_NAMESPACE::uint64_t; using cutlass::uint128_t; - template struct uint_bit; template <> struct uint_bit< 1> { using type = uint1_t; }; template <> struct uint_bit< 2> { using type = uint2_t; }; @@ -95,7 +94,6 @@ template <> struct uint_bit< 16> { using type = uint16_t; }; template <> struct uint_bit< 32> { using type = uint32_t; }; template <> struct uint_bit< 64> { using type = uint64_t; }; template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; - template using uint_bit_t = typename uint_bit::type; diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp index 5218ba37..be22ab37 100644 --- a/include/cute/tensor_impl.hpp +++ b/include/cute/tensor_impl.hpp @@ -235,7 +235,7 @@ struct Tensor decltype(auto) operator()(Coord const& coord) { if constexpr (has_underscore::value) { - auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + auto [sliced_layout,offset] = slice_and_offset(coord, layout()); return make_tensor(data() + offset, sliced_layout); } else { return data()[layout()(coord)]; @@ -249,7 +249,7 @@ struct Tensor 
decltype(auto) operator()(Coord const& coord) const { if constexpr (has_underscore::value) { - auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + auto [sliced_layout,offset] = slice_and_offset(coord, layout()); return make_tensor(data() + offset, sliced_layout); } else { return data()[layout()(coord)]; diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h index 10b6af8a..8dea5800 100644 --- a/include/cutlass/arch/config.h +++ b/include/cutlass/arch/config.h @@ -102,6 +102,7 @@ #if (!defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM100_ALL)) #define CUTLASS_ARCH_MMA_SM100A_ENABLED 1 #endif + #endif #endif diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index 766c2223..707e1d75 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -38,10 +38,14 @@ #include "cutlass/cutlass.h" #ifndef CUDA_CTA_RECONFIG_ACTIVATED - #if (__CUDACC_VER_MAJOR__ >= 12 && \ - defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #if defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 12 && ( \ + (__CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) \ + || (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) \ + ) #define CUDA_CTA_RECONFIG_ACTIVATED 1 #endif + + #endif namespace cutlass { diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp index 74b0e011..27eed799 100644 --- a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -106,7 +106,6 @@ struct CollectiveConv< using ProblemShape = ConvProblemShape; - // TODO: move pipeline mode tiling into the collective setup phase instead static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index b5d78ce5..fbef858a 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -255,23 +255,27 @@ public: CUTLASS_HOST_DEVICE int64_t activation_size() const { - return (N * H * W * C); + return static_cast(N) * static_cast(H) * + static_cast(W) * static_cast(C); } /// Returns filter size in number of elements CUTLASS_HOST_DEVICE int64_t filter_size() const { - return (K * R * S * C / groups); + return static_cast(K) * static_cast(R) * + static_cast(S) * static_cast(C) / + static_cast(groups); } /// Returns output size in number of elements CUTLASS_HOST_DEVICE int64_t output_size() const { - return (N * P * Q * K); + return static_cast(N) * static_cast(P) * + static_cast(Q) * static_cast(K); } - + /// Returns padding as Tensor4DCoord CUTLASS_HOST_DEVICE cutlass::Tensor4DCoord padding() const { diff --git a/include/cutlass/conv/conv3d_problem_size.h b/include/cutlass/conv/conv3d_problem_size.h index a7e08361..48bf056e 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -285,21 +285,27 @@ public: CUTLASS_HOST_DEVICE int64_t activation_size() const { - return (N * D * H * W * C); + return 
static_cast(N) * static_cast(D) * + static_cast(H) * static_cast(W) * + static_cast(C); } /// Returns filter size in number of elements CUTLASS_HOST_DEVICE int64_t filter_size() const { - return (K * T * R * S * C); + return static_cast(K) * static_cast(T) * + static_cast(R) * static_cast(S) * + static_cast(C); } /// Returns output size in number of elements CUTLASS_HOST_DEVICE int64_t output_size() const { - return (N * Z * P * Q * K); + return static_cast(N) * static_cast(Z) * + static_cast(P) * static_cast(Q) * + static_cast(K); } /// Returns padding as Coord3D diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index f166afc8..a9aae87b 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -114,6 +114,33 @@ public: return status; } + // Check that tensor sizes don't exceed maximum supported size + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.activation_size() * sizeof(ElementA) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementC) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + else if (kConvolutionalOperator == conv::Operator::kDgrad || + kConvolutionalOperator == conv::Operator::kDeconv) { + if (args.problem_size.activation_size() * sizeof(ElementC) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementB) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.activation_size() * sizeof(ElementB) >= + (1ull << 31) || + args.problem_size.filter_size() * sizeof(ElementC) >= (1ull << 31) || + args.problem_size.output_size() * sizeof(ElementA) >= (1ull << 31)) { + return Status::kErrorInvalidProblem; + } + } + // check group conv constraint if (args.problem_size.groups != 1) { if (kGroupMode == conv::GroupMode::kNone) { diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index c9cd4421..98e77893 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -104,7 +104,7 @@ namespace cutlass { #else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) -#if (__CUDACC_VER_MAJOR__ >= 13) +#if (__CUDACC_VER_MAJOR__ > 12) #define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ template \ @@ -142,7 +142,7 @@ namespace cutlass { return reinterpret_cast(pfn)(args...); \ } -#endif // (__CUDACC_VERSION__ >= 12.5) +#endif // (__CUDACC_VER_MAJOR__ > 12) #endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp index c9042351..aed30bee 100644 --- a/include/cutlass/detail/collective/mixed_input_utils.hpp +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -69,7 +69,7 @@ struct LayoutAwareConvertImpl { auto&& src_vm = cute::recast(src); auto&& dst_vm = cute::recast(dst); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i -constexpr auto -sm100_compute_tile_shape_or_override() { - using namespace cute; - - if constexpr (cute::is_same_v && - cute::is_same_v && - size<1>(CtaTileShape_MNK{}) == 256) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int DpFull = 
32; - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - // Note: - // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. - // This is a general workable epi_tile_N which does not promise best perf. - return make_tile(Int{}, Int<128>{}); - } - else if constexpr (cute::is_same_v) { - constexpr int CtaM = size<0>(CtaTileShape_MNK{}); - constexpr int CtaN = size<1>(CtaTileShape_MNK{}); - constexpr int WarpM = size<0>(TmemWarpShape_MN{}); - constexpr int WarpN = size<1>(TmemWarpShape_MN{}); - constexpr bool DisableSource = is_void_v; - constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); - - constexpr int DpFull = 32; // tmem datapaths in 1 subpartition - constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load - constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf - // Epilogues w/o residual load are less sensitive to smem allocation - // Target a fixed amount of compute per epilogue iteration - if (DisableSource) { - if (MaxBits == 4) { - // Make epilogue tile larger to reduce the epilogue iterations. - // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. - constexpr int ComputeElts = 8192; - return ComputeElts / M; - } - constexpr int ComputeElts = 4096; - return ComputeElts / M; - } - // Epilogues w/ residual load are more sensitive to smem allocation - // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize - else { - if (MaxBits == 32) { - return (CtaM > 64 && CtaN <= 128) ? 16 : 32; - } - // Per-column scaling is high register pressure, reduce tile to prevent spills - else if (FusionOp::IsPerColScaleSupported) { - return 32; - } - else if (MaxBits == 16) { - return (CtaN <= 128) ? 32 : 64; - } - else { - return 64; - } - } - }(); - constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN - : (sizeof_bits_v == 6) ? 
128 * WarpN // TMA store only supports SW128B for FP6 data type - : 128 / sizeof_bits_v * WarpN; - constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); - static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); - - // stride by tmem warp layout and return a by-mode tiler - auto tile_m = Layout>{}; - auto tile_n = Layout,Int< WarpN>>, - Stride,Int>>{}; - - return make_tile(tile_m, coalesce(tile_n)); - } - else if constexpr (cute::is_tuple::value) { - EpilogueTileType epi_tile; - constexpr int M = size<0>(shape(epi_tile)); - constexpr int N = size<1>(shape(epi_tile)); - - static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); - static_assert(TmemWarpShape_MN{} == Shape<_2,_2>{} && (M == 32 || M == 64) || - TmemWarpShape_MN{} == Shape<_4,_1>{} && (M == 64 || M == 128), "Unsupported tile shape"); - static_assert(N % 8 == 0, "Unsupported tile shape"); - - return epi_tile; - } - else { - static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); - } -} - -template -static constexpr bool IsPtrArrayDispatchPolicy = - cute::is_same_v || - cute::is_same_v; - - -template < - class CtaTileShape_MNK, - class EpilogueTile_MN, - class ElementC, - class ElementD, - class Schedule -> -constexpr auto -sm100_get_tma_dispatch_policy() { - using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{}))); - constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); - constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; - // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation - constexpr bool ReuseSmem = sizeof_bits_v > 8; - constexpr bool DelayTmaStore = false; - constexpr int StagesD = cute::min(EpiTiles, 2); - constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) - : cute::min(EpiTiles, 4); - - if constexpr (detail::IsPtrArrayDispatchPolicy) { - return Sm100PtrArrayTmaWarpSpecialized{}; - } - else - { - return Sm100TmaWarpSpecialized{}; - } -} - /* * Returns the TMEM_LOAD copy op to be used for the epilogue * Returned TMEM_LOAD op is such that the thread-value ownership matches the widest available @@ -344,10 +208,10 @@ sm100_get_tmem_load_op() { // For complex TF32 kernels else if constexpr (sizeof_bits_v == 64 && sizeof_bits_v == 64) { if constexpr (num_dp == 16) { - return TMEM::op_repeater(); + return TMEM::op_repeater(); } else { - return TMEM::op_repeater(); + return TMEM::op_repeater(); } } // For narrow precision output @@ -376,7 +240,6 @@ sm100_get_smem_store_op() { static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); // Check for TMEM_LOAD layouts that match the thread-value ownership pattern of stmatrix - // TODO: check copy vectorization instead! 
constexpr bool use_stmatrix_m8n8_4x = (sizeof_bits_v == 32 && sizeof_bits_v == 32 && is_n_major && ( cute::is_same_v || @@ -451,22 +314,7 @@ sm100_get_smem_store_op() { } } -template -constexpr auto -sm100_get_register_transform_op() { - using namespace cute; - [[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); - [[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); - static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); - - if constexpr (sizeof_bits_v == 4 && is_m_major) { - return SM50_Shuffle_U32_2x2Trans_XOR1{}; - } - else { - return AutoVectorizingCopyWithAssumedAlignment<128>{}; - } -} // Selects the largest vectorized smem load atom available // subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership @@ -503,30 +351,6 @@ sm100_get_smem_load_op() { } } -template -constexpr auto -sm100_get_gmem_load_op() { - if constexpr (detail::is_im2col_mode) { - return SM90_TMA_LOAD_IM2COL{}; - } - else { - - return SM90_TMA_LOAD{}; - } -} - -template -constexpr auto -sm100_get_gmem_store_op() { - if constexpr (detail::is_im2col_mode) { - return SM90_TMA_STORE_IM2COL{}; - } - else { - - return SM90_TMA_STORE{}; - } -} - // aux fusion callbacks builder for sm100 tma epilogue template < int StagesC, @@ -622,9 +446,9 @@ struct CallbacksBuilder< // the fusion operation performed and the dispatch policy to use. template < class OpClass, - class CtaTileShape_MNK, + class MmaTileShape_MNK, + class ClusterShape_MNK, class EpilogueTileType, - class TmemWarpShape_MN, class ElementAccumulator, class ElementCompute, class ElementC_, @@ -637,62 +461,237 @@ template < class FusionOpOrCallbacks > struct Sm100TmaBuilderImpl { +private: + static constexpr bool Is1SmMma = is_base_of_v; + static constexpr bool Is2SmMma = is_base_of_v; + static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); + static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); + // Passing void C disables source load + smem allocation - using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages - using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; - - using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; - using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - - using CopyOpS2G = decltype(detail::sm100_get_gmem_store_op()); - using CopyOpG2S = decltype(detail::sm100_get_gmem_load_op()); - - using FusionOp = conditional_t, - FusionOpOrCallbacks, epilogue::fusion::FusionOperation>; - - using EpilogueTile_MN = decltype(detail::sm100_compute_tile_shape_or_override< - OpClass, CtaTileShape_MNK, EpilogueTileType, TmemWarpShape_MN, - ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, FusionOp>()); - using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{}))); - using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); - using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + static constexpr bool DisableSource = cute::is_void_v; + using ElementC = cute::conditional_t; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t; using InternalSmemElementC = typename cutlass::detail::get_unpacked_element_type::type; using InternalSmemElementD = typename cutlass::detail::get_unpacked_element_type::type; - using DispatchPolicy = 
decltype(detail::sm100_get_tma_dispatch_policy< - CtaTileShape_MNK, EpilogueTile_MN, ElementC_, ElementD, Schedule>()); + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination - using FusionCallbacks = - typename CallbacksBuilder< - DispatchPolicy, - FusionOpOrCallbacks, - CtaTileShape_MNK, - EpilogueTile_MN, - ElementAccumulator, - AccLoadOp - >::Callbacks; + static constexpr bool IsTaggedFusionOp = is_base_of_v; + using FusionOp = conditional_t; + static constexpr auto + cta_tile_shape() { + if constexpr (Is2SmMma) { // 2x1 threadblock shape + auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{}; + auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode + return make_shape(cta_tile_m, mma_tile_n, mma_tile_k); + } + else { // 1x1 threadblock shape + return MmaTileShape_MNK{}; + } + } + using CtaTileShape_MNK = decltype(cta_tile_shape()); + + static constexpr auto + tmem_warps() { + if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { + return Shape<_2,_2>{}; + } + else { + return Shape<_4,_1>{}; + } + } + using TmemWarpShape_MN = decltype(tmem_warps()); + + // Attempts to compute a reasonably performant epilogue tile or allows the user to provide one. + static constexpr auto + epilogue_tile() { + using namespace cute; + + if constexpr (is_same_v && + is_same_v && + size<1>(CtaTileShape_MNK{}) == 256) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int DpFull = 32; + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + // Note: + // Set Epi_Tile_N to 128 support OverlappingAccum for the largest tile. + // This is a general workable epi_tile_N which does not promise best perf. + return make_tile(Int{}, Int<128>{}); + } + else if constexpr (is_same_v) { + constexpr int CtaM = size<0>(CtaTileShape_MNK{}); + constexpr int CtaN = size<1>(CtaTileShape_MNK{}); + constexpr int WarpM = size<0>(TmemWarpShape_MN{}); + constexpr int WarpN = size<1>(TmemWarpShape_MN{}); + constexpr int MaxBits = cute::max(sizeof_bits_v, sizeof_bits_v); + + constexpr int DpFull = 32; // tmem datapaths in 1 subpartition + constexpr int M = cute::min(CtaM, DpFull * WarpM); // target 32dp tmem load + constexpr int N_perf = [&]() constexpr { // Known subtile sizes tested for perf + // Epilogues w/o residual load are less sensitive to smem allocation + // Target a fixed amount of compute per epilogue iteration + if (DisableSource) { + if (MaxBits == 4) { + // Make epilogue tile larger to reduce the epilogue iterations. + // 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. + constexpr int ComputeElts = 8192; + return ComputeElts / M; + } + constexpr int ComputeElts = 4096; + return ComputeElts / M; + } + // Epilogues w/ residual load are more sensitive to smem allocation + // Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + else { + if (MaxBits == 32) { + return (CtaM > 64 && CtaN <= 128) ? 16 : 32; + } + // Per-column scaling is high register pressure, reduce tile to prevent spills + else if (FusionOp::IsPerColScaleSupported) { + return 32; + } + else if (MaxBits == 16) { + return (CtaN <= 128) ? 
32 : 64; + } + else { + return 64; + } + } + }(); + constexpr int N_min_C = (DisableSource || detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N_min_D = (detail::is_m_major()) ? 8 * WarpN + : (sizeof_bits_v == 6) ? 128 * WarpN // TMA store only supports SW128B for FP6 data type + : 128 / sizeof_bits_v * WarpN; + constexpr int N = cute::min(CtaN, cute::max(N_perf, N_min_C, N_min_D)); + static_assert(CtaN >= N_min_C && CtaN >= N_min_D, "CTA tile too small"); + + // stride by tmem warp layout and return a by-mode tiler + auto tile_m = Layout>{}; + auto tile_n = Layout,Int< WarpN>>, + Stride,Int>>{}; + + return make_tile(tile_m, coalesce(tile_n)); + } + else { + static_assert(cute::is_tuple::value && not is_layout::value, + "EpilogueTile must be a cute::Tile or cute::Shape"); + + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + static_assert(N % 8 == 0, "Unsupported tile shape"); + + return epi_tile; + } + } + using EpilogueTile_MN = decltype(epilogue_tile()); + + using EpilogueTileShape_MN = decltype(product_each(shape(EpilogueTile_MN{}))); + static constexpr int EpiTiles = size(shape_div(take<0,2>(CtaTileShape_MNK{}), EpilogueTileShape_MN{})); + static constexpr int FragmentSize = size(EpilogueTileShape_MN{}) / NumThreadsPerWarpGroup; + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTileShape_MN{}, TmemWarpShape_MN{})); + using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< + GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + + static constexpr auto + dispatch_policy() { + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = sizeof_bits_v > 8; + // TMA store delay performs worse with residual loads + constexpr bool DelayTmaStore = is_void_v; + + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? 
cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (is_same_v || + is_same_v) { + constexpr bool DelayTmaStore_ = false; // TMA store delay complicates tensormap updates for Ptr-Array GEMMs + return Sm100PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm100TmaWarpSpecialized{}; + } + } + + static constexpr auto + fusion_callbacks() { + { + return typename CallbacksBuilder< + decltype(dispatch_policy()), + FusionOpOrCallbacks, + CtaTileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + AccLoadOp + >::Callbacks({},{}); + } + } + + static constexpr auto + gmem_load_op() { + if constexpr (detail::is_im2col_mode) { + return SM90_TMA_LOAD_IM2COL{}; + } + else { + return SM90_TMA_LOAD{}; + } + } + + static constexpr auto + gmem_store_op() { + if constexpr (detail::is_im2col_mode) { + return SM90_TMA_STORE_IM2COL{}; + } + else { + return SM90_TMA_STORE{}; + } + } + + static constexpr auto + register_shuffle_op() { + using namespace cute; + + [[maybe_unused]] constexpr bool is_m_major = cutlass::detail::is_major<0>(GmemStrideTypeD{}); + [[maybe_unused]] constexpr bool is_n_major = cutlass::detail::is_major<1>(GmemStrideTypeD{}); + static_assert(is_m_major || is_n_major, "Unsupported gmem layout"); + + if constexpr (sizeof_bits_v == 4 && is_m_major) { + return SM50_Shuffle_U32_2x2Trans_XOR1{}; + } + else { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } + } + +public: using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< - DispatchPolicy, + decltype(dispatch_policy()), CtaTileShape_MNK, EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeC, ElementD, GmemStrideTypeD, - FusionCallbacks, + decltype(fusion_callbacks()), AccLoadOp, - CopyOpG2S, + decltype(gmem_load_op()), decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm100_get_smem_load_op()), - CopyOpS2G, + decltype(gmem_store_op()), decltype(detail::sm100_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm100_get_smem_store_op()), - decltype(detail::sm100_get_register_transform_op()) + decltype(register_shuffle_op()) >; }; @@ -702,7 +701,8 @@ struct Sm100TmaBuilderImpl { // No smem builder template < - class CtaTileShape_MNK, + class OpClass, + class MmaTileShape_MNK, class ClusterShape_MNK, class EpilogueTileType, class ElementAccumulator, @@ -718,8 +718,8 @@ template < > struct CollectiveBuilder< arch::Sm100, - arch::OpClassTensorOp, - CtaTileShape_MNK, + OpClass, + MmaTileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, @@ -732,11 +732,16 @@ struct CollectiveBuilder< AlignmentD, EpilogueScheduleType, FusionOpOrCallbacks, - cute::enable_if_t || - cute::is_same_v >> { + cute::enable_if_t || + is_base_of_v > +> { +private: + static_assert(cute::sizeof_bits_v != 6, "Output element requires TMA"); - static_assert(cute::is_same_v, "Epilogue subtiling requires smem"); - static_assert(cute::sizeof_bits_v != 4 and cute::sizeof_bits_v != 6, "Output element requires smem"); + static constexpr bool Is1SmMma = is_base_of_v; + static constexpr bool Is2SmMma = is_base_of_v; + static_assert(Is1SmMma ^ Is2SmMma, "unsupported schedule"); + static_assert(not (Is2SmMma && size<0>(ClusterShape_MNK{}) % 2 == 1), "schedule + cluster mismatch"); static constexpr bool DisableSource = cute::is_void_v; using ElementC = cute::conditional_t; // prevents void ref breakages @@ -744,173 +749,110 @@ struct CollectiveBuilder< using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using 
GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - using FusionOp = conditional_t, - FusionOpOrCallbacks, epilogue::fusion::FusionOperation>; + static constexpr bool IsTaggedFusionOp = is_base_of_v; + using FusionOp = conditional_t; - // use a 4x2 division to select tmem load shape in order to maintain compatability with both (4,1) and (2,2) layouts - using EpilogueTile = decltype(take<0,2>(CtaTileShape_MNK{})); - using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_2>{})); + static constexpr auto + cta_tile_shape() { + if constexpr (Is2SmMma) { // 2x1 threadblock shape + auto [mma_tile_m, mma_tile_n, mma_tile_k] = MmaTileShape_MNK{}; + auto cta_tile_m = reverse(shape_div(reverse(mma_tile_m), _2{})); // first MmaTile_M/2 elements, preserve multimode + return make_shape(cta_tile_m, mma_tile_n, mma_tile_k); + } + else { // 1x1 threadblock shape + return MmaTileShape_MNK{}; + } + } + using CtaTileShape_MNK = decltype(cta_tile_shape()); + + static constexpr auto + tmem_warps() { + if constexpr (Is2SmMma && size<0>(MmaTileShape_MNK{}) == 128) { + return Shape<_2,_2>{}; + } + else { + return Shape<_4,_1>{}; + } + } + using TmemWarpShape_MN = decltype(tmem_warps()); + + static constexpr auto + epilogue_tile() { + using namespace cute; + if constexpr (not is_same_v) { + static_assert(is_tuple_v, "Shape or Tile"); + return EpilogueTileType{}; + } + else if constexpr (is_same_v) { // perf specialized case + constexpr int EpiM = size<0>(CtaTileShape_MNK{}); + constexpr int EpiN = cute::min(_64{}, size<1>(CtaTileShape_MNK{})); + return Shape, Int>{}; + } + else { + return take<0,2>(CtaTileShape_MNK{}); + } + } + using EpilogueTile = decltype(epilogue_tile()); + + using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, TmemWarpShape_MN{})); using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); + static constexpr int FragmentSize = size(EpilogueTile{}) / NumThreadsPerWarpGroup; - using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized; + static constexpr auto + dispatch_policy() { + if constexpr (is_same_v || + is_same_v) { + return Sm100PtrArrayNoSmemWarpSpecialized{}; + } + else { + return Sm100NoSmemWarpSpecialized{}; + } + } + using DispatchPolicy = decltype(dispatch_policy()); - using AlignmentCType = Int; - using AlignmentDType = Int; + static constexpr auto + fusion_callbacks() { + constexpr thread::ScaleType::Kind ScaleType = + DisableSource ? thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + if constexpr (IsDefaultFusionOp::value && not is_same_v) { + // Legacy codepath using thread::LinearCombination, do not expect this to be stable + return thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementCompute, ScaleType, FusionOp::RoundStyle, ElementC>({}); + } + else { + return typename detail::CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + CtaTileShape_MNK, + EpilogueTile, + ElementAccumulator, + AccLoadOp + >::Callbacks({},{}); + } + } - static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; - static constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
- thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; - - using FusionCallbacks = cute::conditional_t< - IsDefaultFusionOp::value, - // Legacy codepath using thread::LinearCombination, do not expect this to be stable - thread::LinearCombination< - ElementD, 1, ElementAccumulator, ElementCompute, - ScaleType, RoundStyle, ElementC> - , - typename detail::CallbacksBuilder< +public: + using CollectiveOp = + cutlass::epilogue::collective::CollectiveEpilogue< DispatchPolicy, - FusionOpOrCallbacks, - CtaTileShape_MNK, - EpilogueTile, - ElementAccumulator, - AccLoadOp - >::Callbacks - >; - - using CollectiveOp = cute::conditional_t< - cute::is_same_v, - cutlass::epilogue::collective::CollectiveEpilogue< - cutlass::epilogue::Sm100NoSmemWarpSpecialized, EpilogueTile, ElementC_, GmemStrideTypeC, ElementD, GmemStrideTypeD, - FusionCallbacks, + decltype(fusion_callbacks()), AccLoadOp, - AlignmentCType, - AlignmentDType - >, - cutlass::epilogue::collective::CollectiveEpilogue< - cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized, - EpilogueTile, - ElementC_, - GmemStrideTypeC, - ElementD, - GmemStrideTypeD, - FusionCallbacks, - AccLoadOp - > - >; -}; - -// No smem builder for OpClassBlockScaledTensorOp -template < - class CtaTileShape_MNK, - class ClusterShape_MNK, - class EpilogueTileType, - class ElementAccumulator, - class ElementCompute, - class ElementC_, - class GmemLayoutTagC_, - int AlignmentC, - class ElementD, - class GmemLayoutTagD, - int AlignmentD, - class EpilogueScheduleType, - class FusionOp -> -struct CollectiveBuilder< - arch::Sm100, - arch::OpClassBlockScaledTensorOp, - CtaTileShape_MNK, - ClusterShape_MNK, - EpilogueTileType, - ElementAccumulator, - ElementCompute, - ElementC_, - GmemLayoutTagC_, - AlignmentC, - ElementD, - GmemLayoutTagD, - AlignmentD, - EpilogueScheduleType, - FusionOp, - cute::enable_if_t || - cute::is_same_v >> { - - static_assert(cute::sizeof_bits_v != 6, "Output element requires smem"); - - static constexpr bool DisableSource = cute::is_void_v; - using ElementC = cute::conditional_t; // prevents void ref breakages - using GmemLayoutTagC = cute::conditional_t; - static constexpr thread::ScaleType::Kind ScaleType = DisableSource ? 
- thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; - using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; - using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - - static_assert(cute::is_tuple::value || cute::is_same_v); - using EpilogueTile = cute::conditional_t, - cute::Shape<_128, _64>, - EpilogueTileType - >; - - using EpilogueWarpTileShape_MN = decltype(shape_div(EpilogueTile{}, Shape<_4,_1>{})); - using AccLoadOp = decltype(detail::sm100_get_tmem_load_op< - GmemStrideTypeD, ElementAccumulator, ElementD, EpilogueWarpTileShape_MN, FusionOp>()); - - using DispatchPolicy = cutlass::epilogue::Sm100NoSmemWarpSpecialized; - - using AlignmentCType = Int; - using AlignmentDType = Int; - - static constexpr FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest; - - static_assert(is_base_of_v, "only support EVT fusions"); - using FusionCallbacks = - typename detail::CallbacksBuilder< - DispatchPolicy, - FusionOp, - CtaTileShape_MNK, - EpilogueTile, - ElementAccumulator, - AccLoadOp - >::Callbacks; - - using CollectiveOp = cute::conditional_t< - cute::is_same_v, - cutlass::epilogue::collective::CollectiveEpilogue< - cutlass::epilogue::Sm100NoSmemWarpSpecialized, - EpilogueTile, - ElementC_, - GmemStrideTypeC, - ElementD, - GmemStrideTypeD, - FusionCallbacks, - AccLoadOp, - AlignmentCType, - AlignmentDType - >, - cutlass::epilogue::collective::CollectiveEpilogue< - cutlass::epilogue::Sm100PtrArrayNoSmemWarpSpecialized, - EpilogueTile, - ElementC_, - GmemStrideTypeC, - ElementD, - GmemStrideTypeD, - FusionCallbacks, - AccLoadOp - > - >; + Int, + Int + >; }; // TMA epilogue builder template < class OpClass, - class CtaTileShape_MNK, // Static CTA tile shape - class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class MmaTileShape_MNK, + class ClusterShape_MNK, class EpilogueTileType, class ElementAccumulator, class ElementCompute, @@ -926,7 +868,7 @@ template < struct CollectiveBuilder< arch::Sm100, OpClass, - CtaTileShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, @@ -940,30 +882,20 @@ struct CollectiveBuilder< EpilogueScheduleType, FusionOp, cute::enable_if_t< - // OpClass - ( cute::is_same_v - || cute::is_same_v - ) && - // Epilogue Schedule Type - ( cute::is_base_of_v || - cute::is_base_of_v - || detail::IsPtrArrayDispatchPolicy - )>> + // Only support TensorOp kernels + not cute::is_same_v && + (cute::is_base_of_v || + cute::is_base_of_v) + > +> { -private: - using TmemWarpShape_MN = cute::conditional_t(CtaTileShape_MNK{}) == 64 && - (cute::is_base_of_v - || cute::is_same_v - ), - Shape<_2,_2>, Shape<_4,_1>>; - public: using CollectiveOp = typename detail::Sm100TmaBuilderImpl< OpClass, - CtaTileShape_MNK, + MmaTileShape_MNK, + ClusterShape_MNK, EpilogueTileType, - TmemWarpShape_MN, ElementAccumulator, ElementCompute, ElementC, @@ -977,11 +909,11 @@ public: >::CollectiveOp; }; -// Auto builder +// Auto epilogue builder for TensorOp kernels template < class OpClass, - class CtaTileShape_MNK, // Static CTA tile shape - class ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + class MmaTileShape_MNK, + class ClusterShape_MNK, class EpilogueTileType, class ElementAccumulator, class ElementCompute, @@ -991,13 +923,12 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class EpilogueScheduleType, class FusionOp > struct CollectiveBuilder< arch::Sm100, OpClass, - CtaTileShape_MNK, + MmaTileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, @@ -1008,30 
+939,41 @@ struct CollectiveBuilder< ElementD, GmemLayoutTagD, AlignmentD, - EpilogueScheduleType, + EpilogueScheduleAuto, FusionOp, - cute::enable_if_t< - // OpClass - ( cute::is_same_v - || cute::is_same_v - ) - // Epilogue Schedule Type - && cute::is_same_v> + // only for TensorOp kernels + cute::enable_if_t> > { private: - static_assert(cute::is_same_v, "Don't specify epilogue tile with auto schedule"); - using TmemWarpShape_MN = cute::conditional_t(CtaTileShape_MNK{}) == 64 && - size<0>(ClusterShape_MNK{}) % 2 == 0 - , - Shape<_2,_2>, Shape<_4,_1>>; + static constexpr bool + is_2sm() { + using namespace cute; + constexpr int MmaTileM = size<0>(MmaTileShape_MNK{}); + constexpr int ClusterM = size<0>(ClusterShape_MNK{}); + constexpr bool StaticClusterM = is_static_v(ClusterShape_MNK{}))>; + constexpr bool EvenClusterM = StaticClusterM && ClusterM % 2 == 0; + if constexpr (not EvenClusterM) { + return false; + } + else if constexpr (is_same_v) { + return MmaTileM == 256; + } + else { + return MmaTileM == 256 || MmaTileM == 128; + } + } + using EpilogueSchedule = cute::conditional_t; + public: + static_assert(cute::is_same_v, "Don't specify epilogue tile with auto schedule"); using CollectiveOp = - typename detail::Sm100TmaBuilderImpl< + typename CollectiveBuilder< + arch::Sm100, OpClass, - CtaTileShape_MNK, + MmaTileShape_MNK, + ClusterShape_MNK, EpilogueTileType, - TmemWarpShape_MN, ElementAccumulator, ElementCompute, ElementC, @@ -1040,7 +982,7 @@ public: ElementD, GmemLayoutTagD, AlignmentD, - EpilogueScheduleType, + EpilogueSchedule, FusionOp >::CollectiveOp; }; diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp index c1b06b06..80eea5e2 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_nosmem.hpp @@ -356,24 +356,21 @@ public: } // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor mC = make_tensor(make_gmem_ptr(ptr_C_l), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); - Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N) + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); 
// (CTA_M,CTA_N) -> (m,n,l) - Tensor cD_epi = flat_divide( cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cD = thread_t2r.partition_D(cD); // (T2R,T2R_M,T2R_N) -> (m,n,l) // 2. Apply element-wise operation and store to gmem // source is needed @@ -410,7 +407,9 @@ template < class ElementD_, class StrideD_, class ThreadEpilogueOp_, - class CopyOpT2R_ + class CopyOpT2R_, + class AlignmentC, + class AlignmentD > class CollectiveEpilogue< Sm100PtrArrayNoSmemWarpSpecialized, @@ -420,7 +419,9 @@ class CollectiveEpilogue< ElementD_, StrideD_, ThreadEpilogueOp_, - CopyOpT2R_ + CopyOpT2R_, + AlignmentC, + AlignmentD > : public detail::Sm100TmaWarpSpecializedAdapter(cta_tile_shape_mnk); // Represent the full output tensor, slice to get the tile this CTA is responsible for - Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) - Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) - Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) - Tensor gC_epi = flat_divide( gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide( gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor mC = make_tensor(make_gmem_ptr(params.ptr_C), problem_shape_mnl, append<3>(params.dC,_0{})); // (M,N,L) + Tensor mD = make_tensor(make_gmem_ptr(params.ptr_D), problem_shape_mnl, append<3>(params.dD,_0{})); // (M,N,L) + Tensor gC = local_tile(mC, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) + Tensor gD = local_tile(mD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) // Partition source and destination tiles according to tmem copy T2R partitioning (tTR_) auto thread_t2r = tiled_t2r.get_slice(threadIdx.x % size(tiled_t2r)); - Tensor tTR_gC = thread_t2r.partition_D(gC_epi); // (T2R,T2R_M,T2R_N) - Tensor tTR_gD = thread_t2r.partition_D(gD_epi); // (T2R,T2R_M,T2R_N) + Tensor tTR_gC = thread_t2r.partition_D(gC); // (T2R,T2R_M,T2R_N) + Tensor tTR_gD = thread_t2r.partition_D(gD); // (T2R,T2R_M,T2R_N) - Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) - Tensor cD_epi = flat_divide( cCD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) + Tensor coordCD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cCD = local_tile(coordCD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor tTR_cCD = thread_t2r.partition_D(cCD); // (T2R,T2R_M,T2R_N) -> (m,n,l) // 2. 
Apply element-wise operation and store to gmem ThreadEpilogueOp epilogue_op{params.thread}; @@ -587,18 +584,18 @@ public: int thread_idx = threadIdx.x % ThreadCount; - Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) - Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tAcc = accumulators(make_coord(_,_),_0{},_0{}); // (CTA_M,CTA_N) + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) TiledCopy tiled_t2r = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); ThrCopy thread_t2r = tiled_t2r.get_slice(thread_idx); - Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc_epi); // (T2R,T2R_M,T2R_N,EPI_M,EPI_N) constexpr int FragmentSize = size(EpilogueTile{}) / ThreadCount; - Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) - Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) + Tensor coordD = make_identity_tensor(problem_shape_mnl); // (M,N,L) -> (m,n,l) + Tensor cD = local_tile(coordD, cta_tiler, cta_coord_mnl); // (CTA_M,CTA_N) -> (m,n,l) Tensor cD_epi = flat_divide(cD, EpilogueTile{}); - Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) + Tensor tTR_cD = thread_t2r.partition_D(cD_epi); // (T2R,T2R_M,T2R_N) -> (m,n,l) Tensor tTR_rAcc = make_tensor(shape(tTR_cD(_,_,_,_0{},_0{}))); @@ -689,19 +686,22 @@ public: do_acc_release = iter_m == size<3>(tTR_tAcc)-1 && iter_n == 0; } - Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); + Tensor tTR_cCD_mn = tTR_cCD(_,_,_,epi_m,epi_n); cst_callbacks.begin_loop(epi_m, epi_n); - if (is_C_load_needed) { - Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int{}))); - Tensor tTR_gC_frg = recast>(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); - Tensor tTR_rC_frg = recast>(coalesce(tCrC)); + if constexpr (not cute::is_void_v) { + if (is_C_load_needed) { + using CVecType = uint_bit_t>; + Tensor tTR_cC_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclC.compose(Int{}))); - auto pred_fn_C = [&] (auto const&... coords) { - return elem_less(tTR_cC_frag(coords...), problem_shape_mnl); - }; + auto pred_fn_C = [&] (auto const&... coords) CUTLASS_LAMBDA_FUNC_INLINE { + return elem_less(tTR_cC_frag(coords...), problem_shape_mnl); + }; - copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg); + Tensor tTR_gC_frg = recast(coalesce(tTR_gC(_,_,_,epi_m,epi_n))); + Tensor tTR_rC_frg = recast(coalesce(tCrC)); + copy_if(pred_fn_C, tTR_gC_frg, tTR_rC_frg); + } } // Copy accumulator tile from tmem to register @@ -733,17 +733,15 @@ public: Tensor tTR_cD_frag = tensor<1>(zipped_divide(coalesce(tTR_cCD_mn), mclD.compose(Int{}))); - - using VecType = uint_bit_t>; - Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); - Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); - auto pred_fn_D = [&] (auto const&... 
coords) CUTLASS_LAMBDA_FUNC_INLINE { return elem_less(tTR_cD_frag(coords...), problem_shape_mnl); }; - copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg); + using VecType = uint_bit_t>; + Tensor tTR_gD_frg = recast(coalesce(tTR_gD(_,_,_,epi_m,epi_n))); + Tensor tTR_rD_frg = recast(coalesce(tTR_rD)); + copy_if(pred_fn_D, tTR_rD_frg, tTR_gD_frg); } // for epi_m } // for epi_n diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp index f8c5b287..c3893675 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -340,7 +340,7 @@ public: _1{}); } - typename Params::TMA_D tma_store_d; + typename Params::TMA_D tma_store_d{}; if constexpr (is_destination_supported) { ElementD const* ptr_D_first_batch = reinterpret_cast(args.ptr_D); Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index f13a6b6f..83302627 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -287,7 +287,7 @@ public: EpilogueTile{}); } - typename Params::TMA_D tma_store_d; + typename Params::TMA_D tma_store_d{}; if constexpr (is_destination_supported) { Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); tma_store_d = make_tma_copy_C_sm90( diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index be1ff675..bd083c80 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -44,35 +44,30 @@ namespace cutlass::epilogue { // Builder Epilogue Schedules // ////////////////////////////////////////////////////////////////////////////// - +// Pre-Hopper schedules struct PtrArrayDefault {}; struct EpilogueSimtVectorized {}; struct EpiloguePtrArraySimtVectorized {}; +// Hopper direct store schedules struct NoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecialized {}; struct PtrArrayNoSmemWarpSpecializedTransposed {}; +// Hopper TMA schedules struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; - +struct PtrArrayTmaWarpSpecialized { static constexpr int NumEpilogueWarpGroups = 1; }; +struct PtrArrayTmaWarpSpecializedPingpong { static constexpr int NumEpilogueWarpGroups = 2; }; +struct PtrArrayTmaWarpSpecializedCooperative { static constexpr int NumEpilogueWarpGroups = 2; }; +// Blackwell direct store schedules +struct NoSmemWarpSpecialized1Sm {}; +struct NoSmemWarpSpecialized2Sm {}; +struct PtrArrayNoSmemWarpSpecialized1Sm : NoSmemWarpSpecialized1Sm {}; +struct PtrArrayNoSmemWarpSpecialized2Sm : NoSmemWarpSpecialized2Sm {}; +// Blackwell TMA schedules struct TmaWarpSpecialized1Sm {}; struct TmaWarpSpecialized2Sm {}; -struct PtrArrayTmaWarpSpecialized1Sm {}; -struct PtrArrayTmaWarpSpecialized2Sm {}; - -struct PtrArrayTmaWarpSpecializedCooperative { - static constexpr int NumEpilogueWarpGroups = 2; -}; - -// Standard warp specialized epilogue -struct PtrArrayTmaWarpSpecialized { - static constexpr int NumEpilogueWarpGroups = 1; -}; - -// Pingpong kernel epilogue -struct 
PtrArrayTmaWarpSpecializedPingpong { - static constexpr int NumEpilogueWarpGroups = 2; -}; - +struct PtrArrayTmaWarpSpecialized1Sm : TmaWarpSpecialized1Sm {}; +struct PtrArrayTmaWarpSpecialized2Sm : TmaWarpSpecialized2Sm {}; // DEPRECATED schedules, will be removed in next release struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index fcd9fc56..8cac28f7 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -53,6 +53,7 @@ struct FusionOperation { // metadata types/queries that can be overrided using ElementOutput = void; using ElementCompute = void; + FloatRoundStyle RoundStyle = FloatRoundStyle::round_indeterminate; using ElementSource = void; static constexpr bool IsSourceSupported = false; diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index ad8f5651..7abed263 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -482,7 +482,6 @@ public: /// Note: The below method only when problem_size_K <= 256 for signed int8 gemm /// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is /// above. -/// TODO: Add logic to fallback to the default approach template < /// Data type used to load and store< tensors typename ElementOutput_, diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index a725a889..279c3aa6 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -39,13 +39,8 @@ #include #endif #if !defined(__QNX__) -#include -#if defined(_MSC_VER) && defined(CCCL_VERSION) && CCCL_VERSION >= 2008000 -#include -#else #include #endif -#endif #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/uint128.h" diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 5d3d6fca..ecbcdff2 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -51,18 +51,57 @@ #ifdef _MSC_VER // Provides support for alternate operators such as 'and', 'or', ... 
#include +#include #endif // _MSC_VER - #if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) # define CUTLASS_ARCH_CREDUX_ENABLED #endif - namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + + CUTLASS_HOST_DEVICE int32_t popcount(int32_t x) { + #if defined(__CUDA_ARCH__) + return __popc(x); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcount(x); + #elif defined(_MSC_VER) + return __popcnt(x); + #else + int32_t count = 0; + while (x) { + count += x & 1; + x >>= 1; + } + return count; + #endif + } + + CUTLASS_HOST_DEVICE int64_t popcount(int64_t x) { + #if defined(__CUDA_ARCH__) + return __popcll(x); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcountll(x); + #elif defined(_MSC_VER) + return __popcnt64(x); + #else + int64_t count = 0; + while (x) { + count += x & 1; + x >>= 1; + } + return count; + #endif + } + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + template struct absolute_value_op { CUTLASS_HOST_DEVICE @@ -609,22 +648,7 @@ struct and_popc_add { CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const { A and_result = a & b; - -#if defined(__CUDA__ARCH__) - int popc_result = __popc(and_result); - - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __popc(static_cast(and_result >> 32)); - } - -#else - int popc_result = __builtin_popcount(and_result); - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __builtin_popcount(static_cast(and_result >> 32)); - } - -#endif - + int32_t popc_result = detail::popcount(and_result); return C(popc_result) + c; } }; @@ -646,22 +670,7 @@ struct xor_popc_add { CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const { A xor_result = a ^ b; - -#if defined(__CUDA__ARCH__) - int popc_result = __popc(xor_result); - - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __popc(static_cast(xor_result >> 32)); - } - -#else - int popc_result = __builtin_popcount(xor_result); - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __builtin_popcount(static_cast(xor_result >> 32)); - } - -#endif - + int32_t popc_result = detail::popcount(xor_result); return C(popc_result) + c; } }; @@ -682,22 +691,7 @@ struct or_popc_add { CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const { A or_result = a | b; - -#if defined(__CUDA__ARCH__) - int popc_result = __popc(or_result); - - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __popc(static_cast(or_result >> 32)); - } - -#else - int popc_result = __builtin_popcount(or_result); - if constexpr (sizeof(A) == sizeof(uint64_t)) { - popc_result += __builtin_popcount(static_cast(or_result >> 32)); - } - -#endif - + int32_t popc_result = detail::popcount(or_result); return C(popc_result) + c; } }; diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 8e53866a..9f2542b5 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -567,7 +567,7 @@ sm100_make_trivial_fastFP32_tiled_mma() { } /** - * @brief Check for U4_UNPACK_U8, U6_UNPACK_U8 alignment requirement + * @brief Check for F8F6F4 alignment requirement * * @tparam TileShape_MNK (MmaAtomShape_M, MmaAtomShape_N, TileShape_K) * @tparam ClusterShape_MNK (cluster_M, cluster_N, cluster_K) diff 
--git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 4209fd87..7736dbee 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -85,7 +85,7 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_co } // Returns the maximum number of smem tiles that can be used with a given smem capacity in gemm of blockwise/groupwise scale. -template +template constexpr int compute_stage_count_with_blockwise_scale(StageCountAutoCarveout stage_count) { constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); @@ -96,7 +96,7 @@ compute_stage_count_with_blockwise_scale(StageCountAutoCarveout cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A - cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B + cutlass::bits_to_bytes(scale_bits * ScaleNsPerTile); // scale of tensor B constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) + static_cast(mainloop_pipeline_bytes); @@ -1043,7 +1043,8 @@ template < class TileShape_MNK, class ClusterShape_MNK, class StageCountType, - int ScaleGranularityM_ + int ScaleGranularityM_, + int ScaleGranularityN_ > struct CollectiveBuilder< arch::Sm90, @@ -1058,11 +1059,11 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum, + KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum, cute::enable_if_t< not detail::is_use_rmem_A()> > { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; static_assert(is_static::value); static_assert(is_static::value); @@ -1090,7 +1091,7 @@ struct CollectiveBuilder< static constexpr bool IsCooperative = cute::is_any_of_v>; + KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>; using AtomLayoutMNK = cute::conditional_t>, Layout>>; @@ -1109,12 +1110,15 @@ struct CollectiveBuilder< static constexpr int KernelSmemCarveout = static_cast(TensorMapStorage); static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_; + static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? 
size<1>(TileShape_MNK{}) : ScaleGranularityN_; static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape_MNK{}) / ScaleGranularityN; static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert((size<1>(TileShape_MNK{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile, ScaleNsPerTile>(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; using SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/fp8_accumulation.hpp b/include/cutlass/gemm/collective/fp8_accumulation.hpp index bd2a0cb2..9dff91a5 100644 --- a/include/cutlass/gemm/collective/fp8_accumulation.hpp +++ b/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -75,6 +75,15 @@ private: } // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA). + CUTLASS_DEVICE + void scale_core(ElementAccumulator const &scale) { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scale; + } + } + template < class EngineScale, class LayoutScale> @@ -94,6 +103,31 @@ private: } } + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_core(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + using TensorScaleA = cute::Tensor; + using TensorScaleB = cute::Tensor; + + static_assert(is_static::value, "ScaleA Layout should be static"); + static_assert(is_static::value, "ScaleB Layout should be static"); + static_assert(is_rmem::value, "ScaleA tensor must be rmem resident."); + static_assert(is_rmem::value, "ScaleB tensor must be rmem resident."); + + static_assert(LayoutAccum{}.shape() == LayoutScaleA{}.shape(), "Accumulator and scaleA must have same shape."); + static_assert(LayoutAccum{}.shape() == LayoutScaleB{}.shape(), "Accumulator and scaleB must have same shape."); + + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i) * scaleA(i) * scaleB(i); + } + } + public: CUTLASS_DEVICE GmmaFP8Accumulation( @@ -152,6 +186,16 @@ public: // /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed. 
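// Illustrative host-side sketch, not part of this patch: it models the promotion-interval
// bookkeeping that GmmaFP8Accumulation's scale_core()/scale_if_needed() implement in this hunk.
// The names promote_if_needed/main_acc/partial_acc/promotion_interval are hypothetical; the real
// class tracks mma_count_ and folds the temporary accumulators into the main FP32 accumulator
// with the A and B block scales once the interval is reached.
inline void promote_if_needed(float* main_acc, float* partial_acc, int count,
                              float scale_a, float scale_b,
                              int& mma_count, int promotion_interval) {
  if (mma_count == promotion_interval) {                  // same trigger condition as scale_if_needed()
    for (int i = 0; i < count; ++i) {
      main_acc[i] += partial_acc[i] * scale_a * scale_b;  // mirrors scale_core(scaleA, scaleB)
      partial_acc[i] = 0.0f;                              // partial sums restart for the next interval
    }
    mma_count = 0;
  }
}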
+ CUTLASS_DEVICE + void scale_if_needed(ElementAccumulator const &scale) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scale); + mma_count_ = 0; + } + } + template < class EngineScale, class LayoutScale> @@ -165,7 +209,29 @@ public: } } + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + scale_core(scaleA, scaleB); + mma_count_ = 0; + } + } + /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void scale_residue_if_needed(ElementAccumulator const &scale) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scale); + } + } + template < class EngineScale, class LayoutScale> @@ -175,6 +241,18 @@ public: scale_core(scale); } } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale_residue_if_needed(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + scale_core(scaleA, scaleB); + } + } }; } // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index 65718878..fec954a5 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -30,8 +30,6 @@ **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -288,23 +286,23 @@ struct CollectiveMma< using TensorStorage = typename SharedStorage::TensorStorage; using PipelineStorage = typename SharedStorage::PipelineStorage; + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly static constexpr uint32_t SFTransactionBytes = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutSFB{})) * cute::sizeof_bits_v); - // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly static constexpr uint32_t ABTmaTransactionBytes = cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); static constexpr uint32_t TmaTransactionBytes = ABTmaTransactionBytes + SFTransactionBytes; - template + template struct TmemStorage { AccTensor accumulators; SfaTensor tCtSFA; SfbTensor tCtSFB; }; - template< + template < class KTileCount, class GTensorPartitionedA, class GTensorPartitionedB, class STensorA, class STensorB, @@ -348,7 +346,8 @@ struct CollectiveMma< , mcast_mask_sfa(mcast_mask_sfa_), mcast_mask_sfb(mcast_mask_sfb_) {} }; - template< + template < + class TiledMma, class FragmentA, class FragmentB, class FragmentSFA, class FragmentSFB, class SFATiledCopy, class SmemFrgSFA, class TmemFrgSFA, @@ -496,6 
+495,7 @@ struct CollectiveMma< Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + // Cluster layout for TMA construction auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); @@ -505,7 +505,7 @@ struct CollectiveMma< // Cluster layout for TMA construction of SFB auto cluster_layout_sfb_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMMA_SF::AtomThrID{})); - auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); + auto cluster_layout_sfb_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMMA_SF::AtomThrID{})); typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( GmemTiledCopyA{}, @@ -649,7 +649,7 @@ struct CollectiveMma< return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); } - template + template CUTLASS_DEVICE static auto init_tmem_tensors(EpilogueTile epi_tile) { @@ -660,7 +660,7 @@ struct CollectiveMma< tiled_mma, acc_shape, EpilogueTile{}); Tensor tCtSFA = make_tensor(shape(SmemLayoutAtomSFA{})); Tensor tCtSFB = make_tensor(shape(SmemLayoutAtomSFB{})); - + TmemStorage tmem_storage; tmem_storage.accumulators = accumulators; tmem_storage.tCtSFA = tCtSFA; @@ -669,10 +669,10 @@ struct CollectiveMma< return tmem_storage; } - template + template CUTLASS_DEVICE static void - set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { tmem_storage.accumulators.data() = tmem_base_addr; tmem_storage.tCtSFA.data() = tmem_storage.accumulators.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.accumulators); tmem_storage.tCtSFB.data() = tmem_storage.tCtSFA.data().get() + cutlass::detail::find_tmem_tensor_col_offset(tmem_storage.tCtSFA); @@ -751,7 +751,6 @@ struct CollectiveMma< Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_SFB.begin()), SmemLayoutSFB{}); // Define the CTA-in-cluster Layout and Coord - Layout cta_layout_mnk = make_layout(cluster_shape_); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); @@ -785,13 +784,11 @@ struct CollectiveMma< uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); - LoadParams load_params { + return LoadParams{ size<3>(gA_mkl), // for scheduler tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values - mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb // multicast masks - }; - return load_params; + mcast_mask_a, mcast_mask_b, mcast_mask_sfa, mcast_mask_sfb}; // multicast masks } /// Set up the data needed by this collective for mma compute. 
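// Minimal standalone illustration, not from this patch, of the refactor in the hunk above:
// load_init() (and mma_init() later in the patch) now returns a braced aggregate directly and
// lets class template argument deduction pick the template parameters instead of naming a
// LoadParams temporary. LoadParamsSketch and make_load_params below are hypothetical stand-ins.
#include <cstdint>

template <class TensorA, class TensorB>
struct LoadParamsSketch {
  int k_tiles;                                     // number of k tiles, for the scheduler
  TensorA gA; TensorB gB;                          // partitioned gmem tensors
  std::uint16_t mcast_mask_a, mcast_mask_b;        // TMA multicast masks
};

// C++17 requires an explicit deduction guide for aggregate CTAD
template <class TensorA, class TensorB>
LoadParamsSketch(int, TensorA, TensorB, std::uint16_t, std::uint16_t) -> LoadParamsSketch<TensorA, TensorB>;

template <class TensorA, class TensorB>
auto make_load_params(int k_tiles, TensorA gA, TensorB gB, std::uint16_t mask_a, std::uint16_t mask_b) {
  return LoadParamsSketch{k_tiles, gA, gB, mask_a, mask_b};  // single-expression return, as in the diff
}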
@@ -802,8 +799,8 @@ struct CollectiveMma< TensorStorage& shared_tensors) const { // Allocate "fragments/descriptors" for A and B matrices - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) // Allocate "fragments/descriptors" for A and B matrices Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) @@ -854,17 +851,12 @@ struct CollectiveMma< tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; } - MmaParams< - decltype(tCrA), decltype(tCrB), decltype(tCtSFA), decltype(tCtSFB), - decltype(tiled_copy_s2t_SFA), decltype(thr_tCsSFA_compact_s2t), decltype(thr_tCtSFA_compact_s2t), - decltype(tiled_copy_s2t_SFB), decltype(thr_tCsSFB_compact_s2t), decltype(thr_tCtSFB_compact_s2t) - > mma_params { + + return MmaParams{ tiled_mma, tCrA, tCrB, tCtSFA, tCtSFB, tiled_copy_s2t_SFA, thr_tCsSFA_compact_s2t, thr_tCtSFA_compact_s2t, - tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t - }; - return mma_params; + tiled_copy_s2t_SFB, thr_tCsSFB_compact_s2t, thr_tCtSFB_compact_s2t}; } /// Perform a collective-scoped matrix multiply-accumulate @@ -983,52 +975,12 @@ struct CollectiveMma< uint32_t skip_wait = k_tile_count <= 0; auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + bool is_first_iter = true; // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; - if (k_tile_count > 0) { // first iteraion - // WAIT on mainloop_pipe_consumer_state until its data are available - // (phase bit flips from mainloop_pipe_consumer_state.phase() value) - mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); - - // Compute on k_tile - int read_stage = mainloop_pipe_consumer_state.index(); - // Save current mainlop pipeline read state - auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; - - // Advance mainloop_pipe - ++mainloop_pipe_consumer_state; - --k_tile_count; - skip_wait = k_tile_count <= 0; - // Peek at next iteration - barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); - - if (cute::elect_one_sync()) { - copy(tiled_copy_s2t_SFA, thr_tCsSFA_s2t(_,_,_,_,read_stage), thr_tCtSFA_s2t); - copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); - } - - if constexpr (IsOverlappingAccum) { - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - } - - // Unroll the K mode manually so we can set scale C to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma.with(tiled_mma.accumulate_, - tCtSFA(_,_,k_block), - tCtSFB_mma(_,_,k_block)), - tCrA(_,_,k_block,read_stage), - tCrB(_,_,k_block,read_stage), - accumulators); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; - } - mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); - } - CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { // WAIT on mainloop_pipe_consumer_state until its data are available @@ -1052,6 +1004,13 @@ struct CollectiveMma< copy(tiled_copy_s2t_SFB, 
thr_tCsSFB_s2t(_,_,_,_,read_stage), thr_tCtSFB_s2t); } + if constexpr (IsOverlappingAccum) { + if (is_first_iter) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + is_first_iter = false; + } + } + // Unroll the K mode manually so we can set scale C to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { @@ -1064,6 +1023,7 @@ struct CollectiveMma< accumulators); tiled_mma.accumulate_ = UMMA::ScaleOut::One; } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); } diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp index c7e56250..f1abb1eb 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -31,7 +31,6 @@ - #pragma once #include "cutlass/cutlass.h" @@ -239,12 +238,12 @@ struct CollectiveMma< cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); - template + template struct TmemStorage { AccTensor accumulators; }; - template< + template < class KTileCount, class GTensorPartitionedA, class GTensorPartitionedB, class STensorA, class STensorB @@ -273,7 +272,10 @@ struct CollectiveMma< , mcast_mask_a(mcast_mask_a_), mcast_mask_b(mcast_mask_b_) {} }; - template + template < + class TiledMma, + class FragmentA, class FragmentB + > struct MmaParams { TiledMma tiled_mma; FragmentA tCrA; @@ -336,7 +338,7 @@ struct CollectiveMma< , runtime_data_type_a_(params.runtime_data_type_a) , runtime_data_type_b_(params.runtime_data_type_b) { if constexpr (IsDynamicCluster) { - const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; observed_tma_load_b_ = is_fallback_cluster ? ¶ms.tma_load_b_fallback : ¶ms.tma_load_b; @@ -461,7 +463,7 @@ struct CollectiveMma< return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); } - template + template CUTLASS_DEVICE static auto init_tmem_tensors(EpilogueTile epi_tile) { @@ -475,10 +477,10 @@ struct CollectiveMma< return tmem_storage; } - template + template CUTLASS_DEVICE static void - set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { tmem_storage.accumulators.data() = tmem_base_addr; } @@ -535,21 +537,21 @@ struct CollectiveMma< // TMA Multicast Masks uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); - - LoadParams load_params { + + return LoadParams{ shape<3>(gA_mkl), // for scheduler tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values - mcast_mask_a, mcast_mask_b // multicast masks - }; - return load_params; + mcast_mask_a, mcast_mask_b}; // multicast masks } /// Set up the data needed by this collective for mma compute. 
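// Standalone sketch, not from this patch, of the mainloop restructuring in the block-scaled
// collective above: the manually peeled "first iteration" block is folded back into the while
// loop, and the one-time accumulator acquire is guarded by an is_first_iter flag instead.
// acquire_accumulator/process_k_tile are hypothetical callables.
template <class AcquireFn, class ProcessFn>
void run_mainloop_sketch(int k_tile_count, AcquireFn&& acquire_accumulator, ProcessFn&& process_k_tile) {
  bool is_first_iter = true;
  while (k_tile_count > 0) {
    if (is_first_iter) {          // previously done in a peeled first iteration before the loop
      acquire_accumulator();
      is_first_iter = false;
    }
    process_k_tile(k_tile_count); // MMA work for this k tile
    --k_tile_count;
  }
}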
- template + template CUTLASS_DEVICE auto mma_init( - [[maybe_unused]] TmemStorage tmem_tensors, - TensorStorage& shared_tensors) const { + [[maybe_unused]] TmemStorage tmem_storage, + TensorStorage& shared_tensors) const { + + // Allocate "fragments/descriptors" for A and B matrices Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -558,7 +560,7 @@ struct CollectiveMma< Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); // PIPE TiledMma tiled_mma; @@ -568,11 +570,10 @@ struct CollectiveMma< tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; } - MmaParams mma_params { + + return MmaParams{ tiled_mma, - tCrA, tCrB - }; - return mma_params; + tCrA, tCrB}; } /// Perform a collective-scoped matrix multiply-accumulate @@ -657,6 +658,7 @@ struct CollectiveMma< ) { static_assert(is_tmem::value, "Accumulator must be tmem resident."); static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + auto accumulators = get<0>(accumulators_pair); auto [tiled_mma, tCrA, tCrB] = mma_inputs; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index e06ead97..546bf915 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -58,6 +58,7 @@ template < class ClusterShape, class KernelSchedule, int ScaleGranularityM_, + int ScaleGranularityN_, class TileShape_, class ElementA_, class StrideA_, @@ -73,7 +74,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, TileShape_, ElementA_, StrideA_, @@ -92,7 +93,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -120,7 +121,9 @@ struct CollectiveMma< static constexpr int NumProducerThreadEvents = 2; static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? 
size<1>(TileShape{}) : ScaleGranularityN_; static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -131,6 +134,7 @@ struct CollectiveMma< static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( @@ -144,12 +148,13 @@ struct CollectiveMma< // Block scaling gmem-to-smem copy atom using BlockScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>; + using BlockScaleCopyTypeB = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>; using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; - using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; // Block scaling smem layout using SmemLayoutScaleA = Layout, Int>>; - using SmemLayoutScaleB = Layout>, Stride<_1>>; // `ScaleNsPerTile` is always 1. + using SmemLayoutScaleB = Layout, Int>>; static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); static_assert(cute::is_base_of::value && @@ -168,7 +173,7 @@ struct CollectiveMma< cute::array_aligned> smem_A; // mxk cute::array_aligned> smem_B; // nxk cute::array_aligned> smem_scale_A; // ScaleMsPerTile x k - cute::array_aligned> smem_scale_B; // 1xk + cute::array_aligned> smem_scale_B; // ScaleNsPerTile x k } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; @@ -322,17 +327,17 @@ struct CollectiveMma< Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) // Make the tiled views of scale tensors - auto scaleA_shape = make_shape(get<2>(gA_mkl.shape()), Int{}, get<3>(gA_mkl.shape()), get<4>(gA_mkl.shape())); // (m,ScaleMsPerTile,k,l) - auto scale_dA = make_stride(get<3>(gA_mkl.shape()) * Int{}, Int<1>{}, Int{}, get<2>(gA_mkl.shape()) * get<3>(gA_mkl.shape()) * Int{}); + auto scaleA_shape = make_shape(shape<2>(gA_mkl), Int{}, shape<3>(gA_mkl), shape<4>(gA_mkl)); // (m,ScaleMsPerTile,k,l) + auto scaleB_shape = make_shape(shape<2>(gB_nkl), Int{}, shape<3>(gB_nkl), shape<4>(gB_nkl)); // (n,ScaleNsPerTile,k,l) + auto scale_dA = compact_order(scaleA_shape, Step<_2,_0,_1,_3>{}); + auto scale_dB = compact_order(scaleB_shape, Step<_2,_0,_1,_3>{}); auto scaleA_layout = make_layout(scaleA_shape, scale_dA); - auto scaleB_shape = make_shape(get<2>(gB_nkl.shape()), get<3>(gB_nkl.shape()), get<4>(gB_nkl.shape())); // (n,k,l) - auto scale_dB = make_stride(get<3>(gB_nkl.shape()), Int<1>{}, get<2>(gB_nkl.shape()) * get<3>(gB_nkl.shape())); auto scaleB_layout = make_layout(scaleB_shape, scale_dB); - // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as 
mScaleA_mkl and mScaleB_nkl. Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (m,ScaleMsPerTile,k,l) - Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,ScaleNsPerTile,k,l) return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); } @@ -356,13 +361,13 @@ struct CollectiveMma< uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { int lane_predicate = cute::elect_one_sync(); - // Blockscaling: Tma loads for load_input and CpAsync for load_scale if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k) // // Prepare the TMA loads for A and B @@ -388,10 +393,10 @@ struct CollectiveMma< Tensor mScaleB_nkl = get<3>(load_inputs); Tensor gScaleA = mScaleA_mkl(m_coord,_,_,l_coord); // (1,ScaleMsPerTile,k,1) - Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1) + Tensor gScaleB = mScaleB_nkl(n_coord,_,_,l_coord); // (1,ScaleNsPerTile,k,1) TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout>{}, Layout>>{}); // (1,ScaleMsPerTile,1) - TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout>{}, Layout>{}); // (1,1,1) + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout>{}, Layout>>{}); // (1,ScaleNsPerTile,1) ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); @@ -446,7 +451,7 @@ struct CollectiveMma< // Copy scale tensors from global memory to shared memory copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); - copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage)); pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); ++k_tile_iter; @@ -508,7 +513,11 @@ struct CollectiveMma< Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, Stride, _0, Int> >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) - Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (m,(ScaleGranularityN,ScaleNsPerTile),k) // // Define C accumulators and A/B partitioning @@ -531,7 +540,8 @@ struct CollectiveMma< TiledMma tiled_mma; auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); - Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. 
+ Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsScaleBViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) @@ -557,11 +567,8 @@ struct CollectiveMma< PipelineState smem_pipe_release = smem_pipe_read; // Per block scale values for operand A and B - using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout. - using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above - - Tensor tCrScaleAViewAsC = make_tensor(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N) - ElementBlockScale scale_b; + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N) // Prologue GMMAs int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); @@ -583,21 +590,26 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers. - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + // Load per block scale values from shared memory to registers + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0]; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; } } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } warpgroup_arrive(); // Unroll the K mode manually to set scale D to 1 @@ -609,8 +621,20 @@ struct CollectiveMma< } warpgroup_commit_batch(); - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + accumulation.scale_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_if_needed(tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_if_needed(tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + } ++smem_pipe_read; } @@ -632,21 +656,26 @@ struct CollectiveMma< int read_stage = smem_pipe_read.index(); - // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N) - scale_b = sScaleB[read_stage]; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { - tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{})); + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; } - if constexpr (ScaleMsPerTile == 1) { - static_assert(size(RegLayoutScaleAEssential{}) == 1); - tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`. 
- } else { + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0]; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) { + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; } } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } if (accumulation.prepare_if_needed()) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; @@ -667,8 +696,20 @@ struct CollectiveMma< warpgroup_wait(); warpgroup_fence_operand(accumulation()); - // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` - accumulation.scale_if_needed(tCrScaleAViewAsC); + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + accumulation.scale_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_if_needed(tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_if_needed(tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + } pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -677,7 +718,19 @@ struct CollectiveMma< ++smem_pipe_release; } - accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + accumulation.scale_residue_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + } warpgroup_fence_operand(accumulation()); } diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 155d023d..8747f48b 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -117,7 +117,11 @@ struct KernelPtrArrayTmaWarpSpecializedPingpong { }; // FP8 related policies (including Blocked Scaled Accumulation) template< - int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M. + // `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value + // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is + // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. 
+ int ScaleGranularityM = 0, + int ScaleGranularityN = 0 > struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; @@ -302,12 +306,16 @@ template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, while zero-value `ScaleGranularityM` indicates that scaling granularity is `size<0>(TileShape_MNK{})` along M. + // `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value + // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is + // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. + int ScaleGranularityM = 0, + int ScaleGranularityN = 0 > struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8 : MainloopSm90TmaGmmaWarpSpecialized { static_assert( - cute::is_same_v>, + cute::is_same_v>, "KernelSchedule must be one of the warp specialized policies"); }; diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 84d70212..41165cfd 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -397,8 +397,6 @@ public: // An example of an unneeded threadblock is one that is assigned to compute in the upper // portion of a Rank2K kernel filled with mode kLower. // - // TODO: Consider pushing these checks into ProblemVisitor to avoid spuriously - // returning from `next_tile()`. // // Early exit if threadblock is out of range diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp index 69748c9c..65885b8a 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp @@ -1131,6 +1131,10 @@ public: } } + else { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } } }; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp index 5d03f921..95cc663b 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -29,8 +29,6 @@ * **************************************************************************************************/ - - #pragma once #include "cutlass/cutlass.h" @@ -564,20 +562,21 @@ public: // Sync deallocation status between MMA warps of peer CTAs arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; - if constexpr(!IsOverlappingAccum) { - if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { - tmem_deallocation_result_barrier.init(NumMMAThreads); + if (WarpCategory::MMA == warp_category) { + if constexpr(!IsOverlappingAccum) { + if (has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } + else { + if (has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); + } + else if (lane_predicate) { + tmem_deallocation_result_barrier.init(NumEpilogueThreads); + } } } - else { - if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { - 
tmem_deallocation_result_barrier.init(NumEpilogueThreads*2); - } - else if (WarpCategory::MMA == warp_category && lane_predicate) { - tmem_deallocation_result_barrier.init(NumEpilogueThreads); - } - } - // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; @@ -699,7 +698,6 @@ public: epilogue_throttle_barrier.arrive(); if constexpr (IsSchedDynamicPersistent) { - // Whether a new CLC query must be performed. // See comment below where this variable is updated for a description of // why this variable is needed. @@ -738,7 +736,6 @@ public: work_tile_info = next_work_tile_info; } while (work_tile_info.is_valid()); clc_pipeline.producer_tail(clc_pipe_producer_state); - } } @@ -963,7 +960,6 @@ public: epi_load_pipe_consumer_state = load_state_next; epi_store_pipe_producer_state = store_state_next; accumulator_pipe_consumer_state = acc_state_next; - do_tail_store = true; } work_tile_info = next_work_tile_info; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp index 4e1d2930..afadb309 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp @@ -1057,6 +1057,10 @@ public: } } + else { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } } }; diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp index f7be566f..8e503353 100644 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -783,7 +783,6 @@ private: int L_idx, Split_idx; params_.sk_params_.divmod_splits_(L_idx, Split_idx, work_tile_info.L_idx); - // TODO: Modularize the SM90 scheduler to pull out and reuse this redundant code int additional_k_tiles = 0; int split_start_offset = params_.sk_params_.big_units_; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 4482e25d..cfb6912c 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -455,8 +455,9 @@ public: auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); - + // Declare work_tile_info, then define it in each of warps that use it. 
+ typename TileScheduler::WorkTileInfo work_tile_info; + // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; @@ -474,6 +475,7 @@ public: cluster_wait_fn(); if (warp_group_role == WarpGroupRole::Producer) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); cutlass::arch::warpgroup_reg_dealloc(); // Mainloop Producer Warp @@ -578,6 +580,7 @@ public: } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + work_tile_info = scheduler.initial_work_tile_info(ClusterShape{}); cutlass::arch::warpgroup_reg_alloc(); CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index aa599a35..e1579d3f 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -265,7 +265,7 @@ struct PersistentTileSchedulerSm90Params { } // In case the maximum number of clusters that could co-exist on the target device is // already calculated using cudaOccupancyMaxActiveClusters - else if (max_active_clusters != 0) { + else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { if (raster_order == RasterOrder::AlongN) { launch_grid.y = max_active_clusters * cluster_shape.n(); } @@ -1204,6 +1204,7 @@ struct PersistentTileSchedulerSm90StreamKParams { KernelHardwareInfo new_hw_info; new_hw_info.device_id = hw_info.device_id; new_hw_info.sm_count = hw_info.sm_count; + new_hw_info.max_active_clusters = hw_info.max_active_clusters; if (new_hw_info.sm_count <= 0) { CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); @@ -1787,7 +1788,7 @@ struct PersistentTileSchedulerSm90GroupParams { } // In case the maximum number of clusters that could co-exist on the target device is // already calculated using cudaOccupancyMaxActiveClusters - else if (max_active_clusters != 0) { + else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { if (raster_order == RasterOrder::AlongN) { launch_grid.y = max_active_clusters * cluster_shape.n(); } @@ -2499,6 +2500,7 @@ struct PersistentTileSchedulerSm100GroupParams { bool is_static_cluster_shape = false) { int const sm_count = hw_info.sm_count; + int const max_active_clusters = hw_info.max_active_clusters; // Round up to nearest multiple of swizzle_size along each mode auto log_swizzle_size = get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle_size); @@ -2542,6 +2544,18 @@ struct PersistentTileSchedulerSm100GroupParams { launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); } } + // In case the maximum number of clusters that could co-exist on the target device is + // already calculated using cudaOccupancyMaxActiveClusters + else if (max_active_clusters != 0 && max_active_clusters * cluster_size <= sm_count) { + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = max_active_clusters * cluster_shape.n(); + } + else { + launch_grid.x = max_active_clusters * cluster_shape.m(); + } + CUTLASS_TRACE_HOST("get_grid_shape(): Proposed GridDims by the scheduler using cudaOccupancyMaxActiveClusters = " + "(" << launch_grid.x << ", " << launch_grid.y << ", " << launch_grid.z << ")\n"); + } else { 
constexpr int max_sm_per_gpc = 20; int cta_per_device = get_max_cta_occupancy(max_sm_per_gpc, cluster_shape, sm_count); diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index f6cc735a..0d1da845 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -142,7 +142,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< "Shape of warp-level Mma must be divisible by operator shape."); // Shape of one individual LDS.128 - // TODO: 32 and 4 are hardcoded, 32-by-4 is logical shape using LdsShape = layout::PitchLinearShape< 32, 4 @@ -458,7 +457,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< "Shape of warp-level Mma must be divisible by operator shape."); // Shape of one individual LDS - // TODO: remove hardcoded 32 and 4 using LdsShape = layout::PitchLinearShape< 32, 4 diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h index d53d6dfd..a5370ff8 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -995,7 +995,6 @@ public: CUTLASS_DEVICE MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - // TODO: fix this if it becomes an issue during warp it reset add_tile_offset(tile_offset); return *this; diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 91e4a9ef..faf64275 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -41,7 +41,6 @@ #pragma once #include - #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/layout/pitch_linear.h" diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index b0c616a7..b44264bb 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -82,7 +82,7 @@ struct get_unpacked_element_type { #include "cutlass/tfloat32.h" #include "cutlass/float8.h" #include "cutlass/uint128.h" -#include "cutlass/exmy_base.h" -#include "cutlass/float_subbyte.h" +#include "cutlass/exmy_base.h" +#include "cutlass/float_subbyte.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/media/docs/blackwell_cluster_launch_control.md b/media/docs/blackwell_cluster_launch_control.md index fe13b960..faebb900 100644 --- a/media/docs/blackwell_cluster_launch_control.md +++ b/media/docs/blackwell_cluster_launch_control.md @@ -2,9 +2,9 @@ ## Overview -A GEMM workload usually consists of three phases: prologue, mainloop and epilogue. Each available SM will process multiple output tiles in series if the number of output tiles are much more than the number of available SMs, completely exposing the overhead of prologue and epilogue. +A GEMM workload usually consists of three phases: prologue, mainloop and epilogue. Each SM will process multiple output tiles in series if the number of output tiles are much more than the number of SMs, completely exposing the overhead of prologue and epilogue. -Consider a GEMM that has `20x20x1` output tiles, running on a GPU with `100` SMs. Only `80` out of the `100` SMs are available. Assume cluster shape is `1x1x1`. The following diagram shows how the schedule would look like for such a kernel. +Consider a GEMM that has `20x20x1` output tiles, running on a GPU with `100` SMs. 
Another kernel is occupying all the resources of `20` SMs, so only `80` SMs can be used. Assume the cluster shape is `1x1x1`. The following diagram shows what the schedule would look like for such a kernel.

*(Figure: how the `20x20x1` output tiles are scheduled across the `80` usable SMs.)*
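
To make the arithmetic behind this example concrete, below is a minimal, hedged sketch in plain C++ (an editorial illustration, not CUTLASS code): it computes how many output tiles each of the `80` usable SMs processes for the `20x20x1` tile grid, which is what repeatedly exposes the prologue and epilogue in a non-persistent launch.

```c++
#include <iostream>

int main() {
  // Example from the text: 20x20x1 output tiles on a 100-SM GPU where another
  // kernel occupies 20 SMs, so only 80 SMs (one worker per SM with a 1x1x1
  // cluster) are usable for this GEMM.
  int tiles_m = 20, tiles_n = 20, tiles_l = 1;
  int usable_sms = 80;

  int total_tiles  = tiles_m * tiles_n * tiles_l;                  // 400 output tiles
  int tiles_per_sm = (total_tiles + usable_sms - 1) / usable_sms;  // ceil(400/80) = 5

  // A non-persistent launch pays the prologue and epilogue once per tile,
  // i.e. 5 times per SM here; a persistent worker pays them only once.
  std::cout << "total tiles: " << total_tiles
            << ", tiles per usable SM: " << tiles_per_sm << "\n";
  return 0;
}
```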

@@ -12,22 +12,22 @@ Consider a GEMM that has `20x20x1` output tiles, running on a GPU with `100` SMs ### Static Scheduler CUTLASS has adopted a software technique named **persistent kernels**. Persistent clusters, or Workers, can stay on the GPU throughout kernel execution and process multiple tiles, hiding prologue and epilogue costs. The tile scheduler statically determines the next output tile to process with zero overhead. -However, static scheduler is susceptible to workload imbalance if some SMs are unavailable. The following diagram illustrates this issue. +However, static scheduler is susceptible to workload imbalance if the resources of some SMs are unavailable. The following diagram illustrates this issue.

*(Figure: workload imbalance under the static scheduler when the resources of some SMs are unavailable.)*
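
As a rough, hedged model of why a statically scheduled persistent grid becomes imbalanced (an editorial sketch in plain C++, not the actual CUTLASS tile scheduler), the snippet below assigns tiles to workers with a fixed `tile_id % num_workers` rule decided at launch time: when the grid is sized for `100` SMs but only `80` are free, the tiles owned by the last `20` workers can only run after resources free up, whereas a dynamic scheduler could let the `80` resident workers drain all tiles.

```c++
#include <iostream>
#include <vector>

int main() {
  // Editorial model only: real persistent schedulers are more sophisticated.
  const int total_tiles = 20 * 20; // 20x20x1 output tiles
  const int launched    = 100;     // persistent workers sized for 100 SMs
  const int resident    = 80;      // SMs actually free for this kernel

  // Static scheduling: worker w owns tiles w, w + launched, w + 2*launched, ...
  std::vector<int> tiles_per_worker(launched, 0);
  for (int tile = 0; tile < total_tiles; ++tile) {
    tiles_per_worker[tile % launched] += 1;
  }

  // Workers that cannot become resident strand their statically assigned tiles
  // until SM resources free up, which shows up as the imbalanced tail above.
  int stranded_tiles = 0;
  for (int w = resident; w < launched; ++w) {
    stranded_tiles += tiles_per_worker[w];
  }

  // A dynamic scheduler would let the 80 resident workers process all tiles:
  // ceil(400 / 80) = 5 tiles per resident worker, with no stranded tail.
  int dynamic_tiles_per_worker = (total_tiles + resident - 1) / resident;

  std::cout << "statically stranded tiles: " << stranded_tiles << "\n";
  std::cout << "dynamic tiles per resident worker: " << dynamic_tiles_per_worker << "\n";
  return 0;
}
```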

### Dynamic Scheduler with Cluster Launch Control -A fundamental limitation of persistent scheduling is that the kernel is unaware of the number of available SMs in real time. Some SMs might be occupied by another kernel and thus be unavailable. This makes it challenging to load-balance work across available SMs. +A fundamental limitation of persistent scheduling is that the number of SMs this kernel can utilize is unknown in real time. Some SMs might be occupied by another kernel and thus their resources are unavailable. This makes it challenging to load-balance work across SMs. Blackwell introduces cluster launch control (CLC) for dynamic scheduling. (See https://docs.nvidia.com/cuda/parallel-thread-execution). With this feature, the kernel launches a grid containing as many threadblocks as there are output tiles to compute in the kernel -- just like one would in a non-persistent kernel. Here we define `ClcID` to be a coordinate from the 3D grid launched on GPU. Cluster launch control follows the below rules: -1. A `ClcID` will be launched as a Worker when there are available SMs. +1. A `ClcID` will be launched as a Worker when there are available resources. 2. A `ClcID` can be queried by an existing Worker via `clusterlaunchcontrol.try_cancel` instruction. 3. Every `ClcID` is guaranteed to be processed by either (1) or (2). -4. Each Worker is pre-loaded with a `ClcID`, which is the coordinate indicated by `{blockIdx.x, blockIdx.y, blockIdx.z}`. -5. `clusterlaunchcontrol.try_cancel` instruction returns either a success signal with a `ClcID` or a decline signal. The most common reason of a decline is that akk `ClcID`s have been processed. +4. Each worker uses the `{blockIdx.x, blockIdx.y, blockIdx.z}` coordinate as the first output tile to process and uses the CLC query for subsequent processing of output tiles. +5. `clusterlaunchcontrol.try_cancel` instruction returns either a success signal with a `ClcID` or a decline signal. The most common reason of a decline is that all `ClcID`s have been processed. 6. Cluster launch control works on the granularity of clusters. For example, a 2x2 persistent worker cluster's query will consume 2x2 `ClcID`s at once. The following diagram shows how the schedule would look like with cluster launch control. diff --git a/media/docs/blackwell_functionality.md b/media/docs/blackwell_functionality.md index a7c6169f..02488a3b 100644 --- a/media/docs/blackwell_functionality.md +++ b/media/docs/blackwell_functionality.md @@ -285,7 +285,9 @@ Layout, and Dispatch Policy combinations for each row of [Table 1](#legacy_gemm_ | 1/2 SM | Epilogue Dispatch Policy | |--------|------------------------------------------| | 1SM | cutlass::epilogue::TmaWarpSpecialized1Sm | +| 1SM | cutlass::epilogue::NoSmemWarpSpecialized1Sm | | 2SM | cutlass::epilogue::TmaWarpSpecialized2Sm | +| 2SM | cutlass::epilogue::NoSmemWarpSpecialized2Sm | **Table 15: Epilogue PerSmTileShape_MNK** | 1/2 SM | MMA tile Shape | PerSmTileShape_MNK | @@ -442,7 +444,7 @@ PerSmTileShape_MNK should be deduced from the mainloop setup. For example, in ab It means each CTA is doing (256 / 2sm) x 256 x 128 output, so the PerSmTileShape_MNK is 128x256x128. 
The possible PerSmTileShape_MNK is listed in [Table 15](#epi_persmtileshape) -The epilogue scheduling policy is configurable, and it is common to set `cutlass::epilogue::TmaWarpSpecialized2Sm` +The epilogue scheduling policy is configurable, and it is common to set `cutlass::epilogue::collective::EpilogueScheduleAuto` to allow the epilogue builder to automatically select the appropriate policy. However, it can also be explicitly defined to use other policies based on the 1sm or 2sm MMA instruction. The available policies are listed in [Table 14](#epi_dispatch). @@ -458,10 +460,6 @@ use other policies based on the 1sm or 2sm MMA instruction. The available polici using ElementAccumulator = float; // Epilogue computation's precision type using ElementCompute = float; - // Cluster size for multicast - using ClusterShape_MNK = Shape<_4,_4,_1>; - // Collective Epilogue takes the output tile shape for 1 CTA - using PerSmTileShape_MNK = Shape<_128,_256,_128>; // // Construct CollectiveEpilogue @@ -469,7 +467,7 @@ use other policies based on the 1sm or 2sm MMA instruction. The available polici using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec - PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + MmaTileShape_MNK, ClusterShape_MNK, // MMA tile shape, and cluster shape cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue ElementC, GmemLayoutC, AlignC, // C tensor description @@ -499,12 +497,12 @@ Typically, GmemLayoutSFD would be same as the GmemLayoutD. using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassBlockScaledTensorOp, // Arch and Tensorop spec - PerSmTileShape_MNK, ClusterShape_MNK, // Epilogue tile shape, and cluster shape + MmaTileShape_MNK, ClusterShape_MNK, // MMA tile shape, and cluster shape cutlass::epilogue::collective::EpilogueTileAuto, // Epilogue subtile shape. Auto will find a suitable tile shape ElementAccumulator, ElementCompute, // Mma instr's accumulator type and compute precision for epilogue ElementC, GmemLayoutC, AlignC, // C tensor description ElementD, GmemLayoutD, AlignD, // D tensor description - cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + cutlass::epilogue::TmaWarpSpecialized2Sm // Epilogue schedule policy FusionOperation // <================================== Pass the fusion config into epilogue builder. >::CollectiveOp; ``` diff --git a/media/docs/fundamental_types.md b/media/docs/fundamental_types.md index e50cd384..3bfc4453 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/fundamental_types.md @@ -32,8 +32,8 @@ CUTLASS defines classes for the following numeric data types. * `type_erased_dynamic_float4_t`: Type agnostic 4 bits signed float allowing the user to provide a specific datatype as runtime argument. * `mx_float8_t` or `mx_float8_t` : Block scaled data type with fp8 element type and float_ue8m0_t scale factor and vector size of 32. * `mx_float6_t` or `mx_float6_t` : Block scaled data type with fp6 element type and float_ue8m0_t scale factor and vector size of 32. -* `mx_float6_t` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 32. 
-* `nv_float4_t` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 16. +* `mx_float4_t` : Block scaled data type with signed e2m1 element type and float_ue8m0_t scale factor and vector size of 32. +* `nv_float4_t` : Block scaled data type with signed e2m1 element type and float_ue4m3_t scale factor and vector size of 16. * `complex`: defines complex-valued data type based on the supplied real-valued numeric type Numeric types in CUTLASS may be used in both host and device code and are intended to function diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 736344b4..057fd2d8 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -308,6 +308,9 @@ GEMM [int] --cluster_m,--cluster-shape::m Cluster shape in the M dimension [int] --cluster_n,--cluster-shape::n Cluster shape in the N dimension [int] --cluster_k,--cluster-shape::k Cluster shape in the K dimension + [int] --cluster_m_fallback,--cluster-shape-fallback::m Fallback cluster shape in the M dimension + [int] --cluster_n_fallback,--cluster-shape-fallback::n Fallback cluster shape in the N dimension + [int] --cluster_k_fallback,--cluster-shape-fallback::k Fallback cluster shape in the K dimension [int] --stages,--threadblock-stages Number of stages of threadblock-scoped matrix multiply [int] --warps_m,--warp-count::m Number of warps within threadblock along the M dimension [int] --warps_n,--warp-count::n Number of warps within threadblock along the N dimension @@ -320,6 +323,7 @@ GEMM [enum] --raster_order={heuristic|H|along_m|M|along_n|N} If supported by kernel, sets the tile raster direction [int] --swizzle_size={1,2,4,8} If supported by kernel, sets the 2D tile swizzle extent (In Hopper, other values will be rounded down to the nearest supported value) [int] --use_pdl,--use-pdl Use PDL (true, false) + [int] --enable_sm90_mixed_dtype_shuffle_test If true, the profiler will test SM90 mixed input kernels that can use shuffled input layouts for better performance [enum] --runtime_input_datatype_a Runtime data type for A matrix, narrow-precision only (e4m3, e5m2, e3m2, e2m3, e2m1) [enum] --runtime_input_datatype_b Runtime data type for B matrix, narrow-precision only (e4m3, e5m2, e3m2, e2m3, e2m1) @@ -360,11 +364,12 @@ Profile when execution is performed on device 0 and the C tensor is located on a $ cutlass_profiler --device=0 --allocations=C:1,D:2 --operation=Gemm --m=1024 --n=1024 --k=128 ``` -The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`. +The format of tensor argument is followed by `:`. The type could be `f32` as 32-bit floating point, `s8` as 8-bit signed integer, etc. The available types can be referred to the `NumericTypeID_enumerants` in [util.cu](tools/library/src/util.cu). The layout could be `row` or `column`. If `--enable_sm90_mixed_dtype_shuffle_test=true` is used, the actual layout of the narrow data type matrix is a shuffled layout, neither `row` nor `column`. In addition to encoded data types, CUTLASS profiler allows non-encoded generic data types, namely `f8`, `f6`, and `f4`, with corresponding encoding specified through GEMM input argument: `--runtime_input_datatype_a` and `--runtime_input_datatype_b`. Currently, six encoding schemes are supported: `e4m3`, `e5m2`, `e3m2`, `e2m3`, and `e2m1`. 
-Cluster shapes can be statically set to `Shape;` and specified via runtime arguments: `cluster_m`, `cluster_n` and `cluster_k` in CUTLASS profiler. One may refer to our CUTLASS Example [73_blackwell_gemm_flexible_cluster](../../examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for more details of the this feature. +Cluster shapes can be statically set to `Shape;` and specified via runtime arguments: `cluster_m`, `cluster_n` and `cluster_k` in CUTLASS profiler. In addition to preferred cluster shapes, a user can also specify fallback cluster shapes via runtime arguments: `cluster_m_fallback`, `cluster_n_fallback` and `cluster_k_fallback` in CUTLASS profiler. Those fallback cluster shapes are smaller shapes than the preferred ones for the hardware to assign when there is no chance to issue a larger preferred CGA cluster to the GPU. There are several rules for using a flexible CGA: 1) Preferred CGA size should be divisible by fallback CGA size. 2) Grid dim should be divisible by preferred CGA size. 3) Preferred CGA and fallback CGA must have the same depth (cluster_dim.z must be equal). One may refer to our CUTLASS Example [73_blackwell_gemm_flexible_cluster](../../examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu) for more details of the this feature. +Please be noted that this feature (flexible cluster shapes within a single grid) is only applicable to `sm100a` kernels. The hardware will rasterize into a single cluster shape for those kernels that do not support this feature even with preferred or fallback cluster shapes assigned. CUTLASS 3.x kernels for Hopper and Blackwell also support a new feature called programatic dependent launch (PDL). This can be enabled with `--use-pdl`, and can overlap the epilogue of the prior kernel with the prologue of the next kernel. This can effectively hide kernel prologues. Using PDL can improve performance for back to back GEMMs. See [dependent kernel launch](dependent_kernel_launch.md) for more information. CUDA graphs can also be used (`--use-cuda-graphs`) with PDL to ensure that smaller kernels are enqueued back-to-back on a stream. @@ -585,6 +590,9 @@ Conv2d [int] --cluster_m,--cluster-shape::m Cluster shape in the M dimension [int] --cluster_n,--cluster-shape::n Cluster shape in the N dimension [int] --cluster_k,--cluster-shape::k Cluster shape in the K dimension + [int] --cluster_m_fallback,--cluster-shape-fallback::m Fallback cluster shape in the M dimension + [int] --cluster_n_fallback,--cluster-shape-fallback::n Fallback cluster shape in the N dimension + [int] --cluster_k_fallback,--cluster-shape-fallback::k Fallback cluster shape in the K dimension [int] --stages,--threadblock-stages Number of stages of threadblock-scoped matrix multiply [int] --warps_m,--warp-count::m Number of warps within threadblock along the M dimension [int] --warps_n,--warp-count::n Number of warps within threadblock along the N dimension diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index dd1b0c6f..cfcd5df1 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -672,11 +672,8 @@ The kernel starts with setting up datatypes and cluster shapes. 
using ElementAccumulator = float; using ElementCompute = float; using ElementBias = cutlass::half_t; - using ClusterTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; - using ClusterShape = Shape<_1,_1,_1>; - using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_1,_1,_1>{})); - using OutputCtaShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); - using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); + using MmaTileShape = cute::Shape<_128,_64,Int<128 / sizeof(ElementA)>>; + using ClusterShape = cute::Shape<_1,_1,_1>; ``` The epilogue needs to be instantiated first as the mainloop collective builder takes the shared memory budget of epilogue in the template parameter list. The 3.x epilogue collective builder API has not changed @@ -688,13 +685,12 @@ for Blackwell, so the epilogue fusion is built in a same way as an SM90 epilogue using FusionOperation = cutlass::epilogue::fusion::LinearCombination< ElementD, ElementCompute, - ElementC, - ElementBias + ElementC >; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - OutputCtaShape, ClusterShape, + MmaTileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, LayoutC, 16 / sizeof(ElementC), @@ -728,8 +724,6 @@ dispatch policies can be in [blackwell_functionality.md](./blackwell_functionali >; ``` -It is worth noting that the mainloop builder takes `MmaTileShape` while the epilogue builder takes `OutputCtaShape`. - Instantiating a blockscaled GEMM kernel is slightly different. Referring to an [MXFP8 GEMM](./../../test/unit/gemm/device/sm100_gemm_mxf8_mxf8_mxf8_tensor_op_f32_auto.cu) sample unit test, it takes a different tensor operation class: ```c++ @@ -742,10 +736,10 @@ are needed in the mainloop builder: ```c++ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, - ElementA, GmemLayoutA, 16, - ElementB, GmemLayoutB, 16, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, ElementAccumulator, - MmaTileShape_MNK, ClusterShape_MNK, + MmaTileShape, ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, cutlass::gemm::KernelScheduleAuto >::CollectiveOp; diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 83a80e81..1e8f5774 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -532,29 +532,12 @@ def tuple_factory_(input_tuple, dtype, constants=[0,1]): if first_non_empty_base is None: first_non_empty_base = [] - # Determine whether or not add an additional byte for empty base classes - additional_byte = False - # Special case for constant tuple - if first_non_empty_base is None: - additional_byte = False - else: - for base in first_non_empty_base: - if base in empty_bases: - additional_byte = True - break - - if additional_byte: - ctype_fields = [("empty_byte", EmptyByte), ] + ctype_fields - # Create the ctype tuple class TupleType(ctypes.Structure): _fields_ = ctype_fields def __init__(self, args) -> None: - if additional_byte: - fields = self._fields_[1:] - else: - fields = self._fields_ + fields = self._fields_ assert len(fields) == len(args) for field, arg in zip(fields, args): diff --git a/python/cutlass_library/conv3x_emitter.py b/python/cutlass_library/conv3x_emitter.py index 46cb56d0..459df607 100644 --- 
a/python/cutlass_library/conv3x_emitter.py +++ b/python/cutlass_library/conv3x_emitter.py @@ -69,7 +69,7 @@ using ${operation_name}_epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ${arch}, ${opcode_class_epi}, - ${output_cta_tile_shape}, // output cta tile shape + ${mma_tile_shape}, // mma tile shape ${cluster_shape}, // cluster shape ${epi_tile_mn}, ${element_accumulator}, @@ -109,26 +109,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< def arch_number_to_type(self, arch: int) -> str: return f"cutlass::arch::Sm{arch}" - def output_cta_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: - # For all three kinds of convolutions, the tile shape's K mode - # differs from GEMM in that needs to be wrapped in a Shape. - # For Wgrad convolutions specifically, - # the N tile shape also needs to be wrapped in a Shape. - m_template = 'cute::_${cta_m}' - if operation.conv_kind == ConvKind.Wgrad: - n_template = 'cute::Shape' - else: - n_template = 'cute::_${cta_n}' - k_template = 'cute::Shape' - - output_cta_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' - values = { - 'cta_m': cta_m, - 'cta_n': cta_n, - 'cta_k': cta_k - } - return Template(output_cta_tile_shape_template).substitute(values) - def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: mma_m = cta_m mma_n = cta_n @@ -223,7 +203,6 @@ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< 'element_accumulator': DataTypeTag[operation.accumulator_type()], 'opcode_class': opcode_class, 'arch': self.arch_number_to_type(operation.arch), - 'output_cta_tile_shape': self.output_cta_tile_shape(operation, cta_m, cta_n, cta_k), 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), 'cluster_shape': self.cluster_shape(operation), 'opcode_class_epi': opcode_class_epi, diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 96733d60..52598d73 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -90,19 +90,32 @@ def hash_cutlass_string(input_string): def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): # Define a dictionary mapping the detected types to runtime values datatype_map = { - '_f4_': '_' + runtime_datatype_a + '_', - '_f6_': '_' + runtime_datatype_b + '_', - '_f8_': '_' + runtime_datatype_a + '_', + 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + 
'_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, } - # Use regex to identify and replace _f4_, _f6_, or _f8_ in the kernel name - def substitute(match): - datatype = match.group(0) # This is the matched "_f4_", "_f6_", or "_f8_" - return datatype_map.get(datatype, datatype) # Replace or leave as is + # Regular expression to detect all the keys in datatype_map + pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')') + + # Replace detected patterns using the dictionary + updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name) - # Regex to find "_f4_", "_f6_", or "_f8_" in the hashed_kernel_name - updated_kernel_name = re.sub(r'_f4_|_f6_|_f8_', substitute, hashed_kernel_name) - return updated_kernel_name # This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 5cc4f8b4..2374a131 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -64,17 +64,15 @@ class GemmOperation: def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, - tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False - - , ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None - - ): + tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False, + ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None): kinds_3x = { GemmKind.Universal3x, GemmKind.SparseUniversal3x, GemmKind.BlockScaledUniversal3x, - GemmKind.GroupedGemmUniversal3x, + GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, } self.is_3x = gemm_kind in kinds_3x self.prefix = "3x" if self.is_3x else "" @@ -87,13 +85,11 @@ class GemmOperation: self.C = C self.D = D - - if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + if is_block_scaled(gemm_kind): self.ScaleFactorA = ScaleFactorA self.ScaleFactorB = ScaleFactorB self.ScaleFactorD = ScaleFactorD["tensor"] self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] - if self.D == None: self.D = self.C @@ -239,13 +235,13 @@ class GemmOperation: element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element], core_name = self.core_name()) - - if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + + if is_block_scaled(self.gemm_kind): d_type_names = DataTypeNames[self.D.element] - + if self.ScaleFactorD.element != DataType.void: d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names - + extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( element_sfa = DataTypeNames[self.ScaleFactorA], element_a = DataTypeNames[self.A.element], @@ -255,7 +251,7 @@ class GemmOperation: element_c = DataTypeNames[self.C.element], element_d = d_type_names, core_name = self.core_name()) - + if self.mixed_input_mode != None: extended_name = extended_name + self.mixed_input_mode_name() return extended_name @@ -298,8 +294,8 @@ class GemmOperation: # Generates a 
short string representing underlying epilogue schedule type def epilogue_schedule_name_3x(self): - - if self.gemm_kind == GemmKind.BlockScaledUniversal3x: + + if is_block_scaled(self.gemm_kind): if self.ScaleFactorD.element != DataType.void: return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] @@ -779,7 +775,7 @@ class EmitGemmUniversal3xInstance: using ${operation_name}_epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ${arch}, ${opcode_class_epi}, - cute::Shape, + cute::Shape, cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, ${epi_tile_mn}, ${element_accumulator}, ${element_epilogue}, @@ -797,7 +793,7 @@ using ${operation_name}_mainloop = ${element_a}, ${layout_a}, ${align_a}, ${element_b}, ${layout_b}, ${align_b}, ${element_accumulator}, - cute::Shape, + cute::Shape, cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, ${stages}, ${kernel_schedule} @@ -855,7 +851,7 @@ ${compile_guard_end} @staticmethod def pointerize_if_grouped(operation, layout): - return layout if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else layout + "* " + return layout if not is_grouped(operation.gemm_kind) else layout + "* " @staticmethod def problem_shape(operation): @@ -863,7 +859,7 @@ ${compile_guard_end} grouped_gemm_shape_type = "cute::Shape" grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" - return gemm_shape_type if operation.gemm_kind != GemmKind.GroupedGemmUniversal3x else grouped_gemm_shape_type + return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type def emit(self, operation): _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") @@ -874,18 +870,12 @@ ${compile_guard_end} opcode_class_main = operation.tile_description.math_instruction.opcode_class opcode_class_epi = opcode_class_main - if opcode_class_main == OpcodeClass.BlockScaledTensorOp: - if operation.epilogue_schedule != EpilogueScheduleType.NoSmemWarpSpecialized: - opcode_class_epi = OpcodeClass.TensorOp - - tile_shape = operation.tile_description.tile_shape instruction_shape = operation.tile_description.math_instruction.instruction_shape cluster_m = operation.tile_description.cluster_shape[0] cluster_n = operation.tile_description.cluster_shape[1] - tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape - tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape + tile_shape_m, tile_shape_n, tile_shape_k = tile_shape # account for static/dynamic cluster shapes cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0] @@ -902,10 +892,8 @@ ${compile_guard_end} if opcode_class_main in [OpcodeClass.TensorOp , OpcodeClass.BlockScaledTensorOp ]: - tile_shape_main_m = instruction_shape[0] - tile_shape_main_n = instruction_shape[1] - tile_shape_epi_m = cta_m - tile_shape_epi_n = cta_n + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] # stage count set to zero indicates builder automatic stage selection @@ -930,35 +918,36 @@ ${compile_guard_end} } epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) - if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void: + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: epilogue_functor = self.emit_block_scale_epilogue_functor(operation) - + else: epilogue_functor = 
self.epilogue_functor.emit_declaration() - - if operation.gemm_kind == GemmKind.BlockScaledUniversal3x and operation.ScaleFactorD.element != DataType.void: + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: epilogue_functor = self.emit_block_scale_epilogue_functor(operation) - + # # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] - is_no_smem_epilogue = operation.epilogue_schedule == EpilogueScheduleType.NoSmemWarpSpecialized if opcode_class_main == OpcodeClass.BlockScaledTensorOp: - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: + is_no_smem_epilogue = operation.epilogue_schedule in [EpilogueScheduleType.NoSmemWarpSpecialized1Sm, EpilogueScheduleType.NoSmemWarpSpecialized2Sm] + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): epi_tile_mn = "cute::Shape" if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized1Sm] - if cta_n == 256 and operation.kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): epi_tile_mn = "cute::Shape" if not is_no_smem_epilogue: - epilogue_schedule_type = EpilogueScheduleTag[EpilogueScheduleType.TmaWarpSpecialized2Sm] + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' - + operation_name_str = operation.procedural_name() layout_a_str = LayoutTag[instance_layout_A] @@ -1041,12 +1030,9 @@ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape( 'opcode_class_main': OpcodeClassTag[opcode_class_main], 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], 'arch': "cutlass::arch::Sm%d" % operation.arch, - 'tile_shape_epi_m': str(tile_shape_epi_m), - 'tile_shape_epi_n': str(tile_shape_epi_n), - 'tile_shape_epi_k': str(tile_shape_epi_k), - 'tile_shape_main_m': str(tile_shape_main_m), - 'tile_shape_main_n': str(tile_shape_main_n), - 'tile_shape_main_k': str(tile_shape_main_k), + 'tile_shape_m': str(tile_shape_m), + 'tile_shape_n': str(tile_shape_n), + 'tile_shape_k': str(tile_shape_k), 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int", 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int", 'cluster_shape_k': 'cute::_' + 
str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int", @@ -1396,7 +1382,8 @@ class EmitGemmConfigurationLibrary: GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, GemmKind.Grouped: EmitGemmGroupedInstance, - GemmKind.GroupedGemmUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance, } self.gemm_kind_wrappers = { @@ -1409,7 +1396,8 @@ class EmitGemmConfigurationLibrary: GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', GemmKind.Grouped: 'GemmGroupedOperation', - GemmKind.GroupedGemmUniversal3x: 'GroupedGemmUniversal3xOperation' + GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation', + GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', } self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index a4cf5f90..d70f9ee8 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -217,8 +217,7 @@ def CreateGemmUniversal3xOperator( gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"] gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]), "vector_size" : data_type["sfd_type"]["vector_size"]} - gemm_kind = GemmKind.BlockScaledUniversal3x - + assert is_block_scaled(gemm_kind) A_dtype = data_type["a_type"] B_dtype = data_type["b_type"] @@ -254,9 +253,6 @@ def CreateGemmUniversal3xOperator( return operations -def is_grouped(gemm_kind): - return gemm_kind == GemmKind.GroupedGemmUniversal3x - # Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts def CreateSparseGemmUniversal3xOperator( manifest, layouts, tile_descriptions, data_types, @@ -6654,11 +6650,13 @@ def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int sm100_cluster_shape_1sm = [ [4,4,1] + , DynamicClusterShape ] sm100_cluster_shape_2sm = [ # cluster_m % 2 == 0 for 2sm [4,4,1] + , DynamicClusterShape ] def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): @@ -6718,6 +6716,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): ] cluster_shapes_1sm = [[1,2,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -6765,6 +6764,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): ] cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -7517,8 +7517,227 @@ def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmK CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return -def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # layouts for ABC and their alignments. 
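Aside: the DynamicClusterShape entries added above feed the tile-sizing rule used throughout the new SM100 generators below; a rough standalone sketch of that rule follows (the DynamicClusterShape sentinel and the cta_tile_1sm helper here are stand-ins for illustration, not the generator's real objects).

# Sketch only: a dynamic cluster collapses the multiplier to (1, 1, 1),
# while a static cluster scales the M/N/K extents of the MMA instruction shape.
DynamicClusterShape = "dynamic"

def cta_tile_1sm(instruction_shape, cluster_shape):
    m_mult, n_mult, k_mult = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape
    m, n, k = instruction_shape
    # The generator multiplies the instruction K extent by 4 to form the CTA tile K.
    return [m * m_mult, n * n_mult, k * 4 * k_mult]

print(cta_tile_1sm([128, 128, 32], [1, 2, 1]))            # [128, 256, 128]
print(cta_tile_1sm([128, 128, 32], DynamicClusterShape))  # [128, 128, 128]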
+ layouts = [ + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + instruction_sizes_1sm = [ + # [64, 128, 32], + [128, 128, 32], + # [64, 256, 32], + [128, 256, 32], + ] + + instruction_sizes_2sm = [ + # [128, 128, 32], + # [128, 256, 32], + [256, 128, 32], + [256, 256, 32], + ] + + ab_types = [ + DataType.f4, DataType.f6, DataType.f8, + DataType.e2m1, DataType.e3m2, DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + min_cc = 100 + max_cc = 130 + + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + # Usage: + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + cluster_shapes_1sm = [ + # [1,2,1], + [2,1,1], + [1,1,1], + # [1,4,1], + [4,4,1] + , DynamicClusterShape + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter out some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + 
layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + # [4,1,1], + # [4,2,1], + [4,4,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + if math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) + else: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): # SM100 MMA with mixed F4/F6/F8 inputs + block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return @@ -7529,7 +7748,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud ] instruction_sizes_1sm = [ - [128, 128, 32], [128, 256, 32], # Mixed F4/F6/F8 block scaled only supports M=128 for 1SM cases + [128, 128, 32], [128, 256, 32], # Block scaled kernels only support M=128 for 1SM cases ] instruction_sizes_2sm = [ @@ -7670,8 +7889,7 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud for data_type in 
data_types: CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]) - ) + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) cluster_shapes_2sm = [ [2,1,1], @@ -7766,21 +7984,21 @@ def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cud if math_inst.instruction_shape[0] == 128: CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]) - ) + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) else: CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] - , tile_schedulers = tile_schedulers(data_type["sfd_type"]) - ) + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) -def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): +def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): # SM100 MMA with F4 + block scale if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): return + grouped = is_grouped(gemm_kind) + # layouts for ABC and their alignments. layouts = [ [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], @@ -7805,7 +8023,7 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio def tile_schedulers(sfdtype): # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, # the epilogue is the traditional linear combination, for which we already have tests with stream-K. - if sfdtype["type"] == DataType.void: + if sfdtype["type"] == DataType.void or grouped: return [TileSchedulerType.Default] else: return [TileSchedulerType.Default, TileSchedulerType.StreamK] @@ -7826,6 +8044,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if (is_runtime_datatype_a != is_runtime_datatype_b): continue + # grouped GEMM does not support runtime data type yet + if grouped and (is_runtime_datatype_a or is_runtime_datatype_b): + continue + math_instructions_1sm.append( MathInstruction( instr_size, @@ -7853,6 +8075,10 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio if (is_runtime_datatype_a != is_runtime_datatype_b): continue + # grouped GEMM does not support runtime data type yet + if grouped and (is_runtime_datatype_a or is_runtime_datatype_b): + continue + math_instructions_2sm.append( MathInstruction( instr_size, @@ -7972,15 +8198,21 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio for data_type in data_types: if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. 
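Aside: a hedged, self-contained sketch of the two grouped-GEMM gates used above — scheduler selection and the runtime-datatype filter. The DataType/TileSchedulerType enums and the helper names here are simplified stand-ins, not the library's real types.

from enum import Enum, auto

class DataType(Enum):
    void = auto(); f4 = auto(); f6 = auto(); f8 = auto(); e2m1 = auto(); ue8m0 = auto()

class TileSchedulerType(Enum):
    Default = auto(); StreamK = auto()

RUNTIME_DATATYPES = {DataType.f4, DataType.f6, DataType.f8}

def tile_schedulers(sfd_type, grouped):
    # Stream-K is only added for a non-void SFD, and never for grouped GEMM,
    # to keep the emitted kernel count down.
    if sfd_type == DataType.void or grouped:
        return [TileSchedulerType.Default]
    return [TileSchedulerType.Default, TileSchedulerType.StreamK]

def keep_math_instruction(a_type, b_type, grouped):
    is_rt_a, is_rt_b = a_type in RUNTIME_DATATYPES, b_type in RUNTIME_DATATYPES
    if is_rt_a != is_rt_b:
        return False   # A/B datatypes must be both static or both dynamic
    if grouped and (is_rt_a or is_rt_b):
        return False   # grouped GEMM does not support runtime datatypes yet
    return True

print(tile_schedulers(DataType.ue8m0, grouped=True))                      # [Default]
print(keep_math_instruction(DataType.f4, DataType.e2m1, grouped=False))   # False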
+ # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm] - fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm] + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) + + nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] + fp4_schedule = [fp4_kernel_schedule, epi_schedule] CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] - , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) if isFp4: CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] - , tile_schedulers=tile_schedulers(data_type["sfd_type"]) + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind ) cluster_shapes_2sm = [ @@ -8085,18 +8317,20 @@ def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_versio for data_type in data_types: if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1): data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. - # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 32, E8 isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 - nvfp4_schedule = [KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto] - fp4_schedule = [KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto] + epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) + + nvfp4_schedule = [nvfp4_kernel_schedule, epi_schedule] + fp4_schedule = [fp4_kernel_schedule, epi_schedule] CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [nvfp4_schedule] - , tile_schedulers=tile_schedulers(data_type["sfd_type"]) - ) + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) if isFp4: CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, [fp4_schedule] - , tile_schedulers=tile_schedulers(data_type["sfd_type"]) - ) + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) @@ -8139,6 +8373,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): MathOperation.multiply_add)] cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -8237,6 +8472,7 @@ def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): ] cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -8353,6 +8589,7 @@ def 
GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): cluster_shapes_1sm = [ [1,2,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -8386,6 +8623,7 @@ def GenerateSM100_TensorOp_32b_UMMA_gemm_stream_k(manifest, cuda_version): cluster_shapes_2sm = [ [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -8431,6 +8669,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): cluster_shapes_1sm = [ [1,2,1], [1,1,1], [4,4,1] + , DynamicClusterShape ] tile_schedulers = [ @@ -8498,6 +8737,7 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): cluster_shapes_2sm = [ [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] + , DynamicClusterShape ] for math_inst in math_instructions_2sm: @@ -8554,6 +8794,125 @@ def GenerateSM100_TensorOp_16b_UMMA_gemm_stream_k(manifest, cuda_version): [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) +def GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 130 + + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [128, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [ + [1,2,1], [2,1,1], [1,1,1], [4,4,1] + , DynamicClusterShape + ] + + tile_schedulers = [ + TileSchedulerType.StreamK, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }] + + # Set alignment d based on Destination format. 
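Aside on the alignment line that follows this comment: it divides a 128-bit granule by the element width, assuming (as elsewhere in cutlass_library) that DataTypeSize is expressed in bits; a tiny illustration with a stand-in width table.

# Illustration only: stand-in width table in bits; the D alignment is the
# number of elements that fit in a 128-bit granule.
DataTypeSize = {"e4m3": 8, "f16": 16, "f32": 32}
for name, bits in DataTypeSize.items():
    print(name, 128 // bits)   # e4m3 -> 16, f16 -> 8, f32 -> 4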
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [256, 256, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + cluster_shapes_2sm = [ + [2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,4,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) def GenerateSM100(manifest, cuda_version): # @@ -8570,13 +8929,19 @@ def GenerateSM100(manifest, cuda_version): GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) # grouped GEMM - GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) - GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + + GenerateSM100_TensorOp_fp8_UMMA_gemm_stream_k(manifest, cuda_version) + + # StreamK is included in regular generation + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) # # Block Scaled Gemm # GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) ################################################################################################### @@ -8955,8 +9320,8 @@ def GenerateSM90(manifest, cuda_version): GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) 
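Aside: the grouped generation paths wired up above depend on the to_grouped_schedule helper added to library.py further down in this diff; below is a simplified, self-contained illustration with stand-in enums (not the library's full mapping table).

from enum import Enum, auto

class Kernel(Enum):
    Nvf4TmaWarpSpecialized1SmSm100 = auto()
    PtrArrayNvf4TmaWarpSpecialized1SmSm100 = auto()

class Epilogue(Enum):
    TmaWarpSpecialized1Sm = auto()
    PtrArrayTmaWarpSpecialized1Sm = auto()

GROUPED_SCHEDULE_MAP = {
    Kernel.Nvf4TmaWarpSpecialized1SmSm100: Kernel.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
    Epilogue.TmaWarpSpecialized1Sm: Epilogue.PtrArrayTmaWarpSpecialized1Sm,
}

def to_grouped_schedule(schedule, grouped):
    # Non-grouped kinds keep their schedule; grouped kinds swap in the
    # PtrArray variant (a KeyError means the schedule has no grouped form).
    return schedule if not grouped else GROUPED_SCHEDULE_MAP[schedule]

print(to_grouped_schedule(Kernel.Nvf4TmaWarpSpecialized1SmSm100, grouped=True))
# Kernel.PtrArrayNvf4TmaWarpSpecialized1SmSm100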
GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_1684(manifest, cuda_version) - GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) - GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedGemmUniversal3x) + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 89e72f2b..bc2cc7b1 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -321,6 +321,12 @@ def is_complex(data_type): return True return False +def is_block_scaled(gemm_kind): + return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + +def is_grouped(gemm_kind): + return gemm_kind in (GemmKind.GroupedUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + # def get_complex_from_real(real_type): for r, c in RealComplexBijection: @@ -482,23 +488,32 @@ class KernelScheduleType(enum.Enum): TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() ImplicitTmaWarpSpecializedSm90 = enum_auto() - + TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() - + Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() - + KernelPtrArrayTmaWarpSpecializedCooperative = enum_auto() KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() KernelPtrArrayTmaWarpSpecializedPingpong = enum_auto() @@ -519,7 +534,7 @@ KernelScheduleTag = { KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', - + KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', @@ -530,16 +545,25 @@ KernelScheduleTag = { KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', 
KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', - + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', - + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", } # @@ -568,16 +592,25 @@ KernelScheduleSuffixes = { KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', - + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', - + KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', KernelScheduleType.KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: 
'_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', } class EpilogueScheduleType(enum.Enum): @@ -585,6 +618,10 @@ class EpilogueScheduleType(enum.Enum): EpilogueTransposed = enum_auto() NoSmemWarpSpecialized = enum_auto() PtrArrayNoSmemWarpSpecialized = enum_auto() + NoSmemWarpSpecialized1Sm = enum_auto() + NoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() @@ -600,6 +637,10 @@ EpilogueScheduleTag = { EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', @@ -616,6 +657,10 @@ EpilogueScheduleSuffixes = { EpilogueScheduleType.EpilogueTransposed: '', EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', EpilogueScheduleType.TmaWarpSpecialized1Sm: '', @@ -636,6 +681,23 @@ EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', } +def to_grouped_schedule(schedule, grouped): + if not grouped: + return schedule + + group_schedule_map = { + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, + EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm: 
EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + } + + return group_schedule_map[schedule] + class TileSchedulerType(enum.Enum): Default = enum_auto() Persistent = enum_auto() @@ -817,7 +879,8 @@ class GemmKind(enum.Enum): PlanarComplexArray = enum_auto() Grouped = enum_auto() BlockScaledUniversal3x = enum_auto() - GroupedGemmUniversal3x = enum_auto() + GroupedUniversal3x = enum_auto() + GroupedBlockScaledUniversal3x = enum_auto() # GemmKindNames = { @@ -830,7 +893,8 @@ GemmKindNames = { GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped", GemmKind.BlockScaledUniversal3x: "gemm_block_scaled", - GemmKind.GroupedGemmUniversal3x: "gemm_grouped", + GemmKind.GroupedUniversal3x: "gemm_grouped", + GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped_block_scaled" } # diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index 984ba33c..6e3038ec 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -489,7 +489,7 @@ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): return [], [] - grouped = gemm_kind == GemmKind.GroupedGemmUniversal3x + grouped = is_grouped(gemm_kind) if grouped: # the following cases are unsupported by grouped GEMM if not is_aligned: diff --git a/test/self_contained_includes/CMakeLists.txt b/test/self_contained_includes/CMakeLists.txt index 7e7b0498..cc151e1a 100644 --- a/test/self_contained_includes/CMakeLists.txt +++ b/test/self_contained_includes/CMakeLists.txt @@ -75,7 +75,6 @@ set(header_files_to_check cute/container/array_subbyte.hpp cute/container/bit_field.hpp cute/container/cuda_types.hpp - cute/container/packed_tuple.hpp cute/container/tuple.hpp cute/container/type_list.hpp @@ -107,13 +106,12 @@ set(header_files_to_check cute/arch/mma_sm70.hpp cute/arch/mma_sm75.hpp cute/arch/mma_sm80.hpp - cute/arch/mma_sm80_sparse.hpp cute/arch/mma_sm90.hpp cute/arch/mma_sm90_desc.hpp cute/arch/mma_sm90_gmma.hpp cute/arch/mma.hpp cute/arch/util.hpp - + cute/arch/cluster_sm100.hpp cute/arch/copy_sm100.hpp cute/arch/copy_sm100_tma.hpp @@ -121,7 +119,7 @@ set(header_files_to_check cute/arch/mma_sm100_desc.hpp cute/arch/mma_sm100_umma.hpp # cute/arch/tmem_allocator_sm100.hpp - + # cute/atom # cute/atom/copy_atom.hpp # cute/atom/copy_traits.hpp @@ -140,10 +138,10 @@ set(header_files_to_check cute/atom/mma_traits_sm80.hpp cute/atom/mma_traits_sm90.hpp cute/atom/mma_traits_sm90_gmma.hpp - - cute/atom/mma_traits_sm100.hpp + + cute/atom/mma_traits_sm100.hpp cute/atom/partitioner.hpp - + # cutlass cutlass/aligned_buffer.h cutlass/array.h @@ -180,7 +178,6 @@ set(header_files_to_check cutlass/numeric_size.h cutlass/numeric_types.h cutlass/pitch_linear_coord.h - cutlass/predicate.h cutlass/predicate_vector.h cutlass/quaternion.h cutlass/real.h @@ -200,16 +197,16 @@ set(header_files_to_check cutlass/workspace.h cutlass/exmy_base.h cutlass/float_subbyte.h - + # cutlass/platform cutlass/platform/platform.h # cutlass/pipeline cutlass/pipeline/pipeline.hpp cutlass/pipeline/sm90_pipeline.hpp - + cutlass/pipeline/sm100_pipeline.hpp - + # cutlass/detail cutlass/detail/cluster.hpp @@ -217,18 +214,16 @@ set(header_files_to_check cutlass/detail/dependent_false.hpp cutlass/detail/helper_macros.hpp cutlass/detail/layout.hpp - cutlass/detail/mainloop_fusion_helper_bgrada.hpp cutlass/detail/mma.hpp - + cutlass/detail/sm100_blockscaled_layout.hpp - + # cutlass/arch cutlass/arch/arch.h cutlass/arch/barrier.h 
cutlass/arch/cache_operation.h cutlass/arch/config.h - cutlass/arch/custom_abi.h cutlass/arch/grid_dependency_control.h cutlass/arch/memory.h # cutlass/arch/memory_sm75.h @@ -248,7 +243,6 @@ set(header_files_to_check # cutlass/arch/simd_sm60.h # cutlass/arch/simd_sm61.h cutlass/arch/reg_reconfig.h - cutlass/arch/tma_operation.h cutlass/arch/wmma.h # cutlass/arch/wmma_sm70.h # cutlass/arch/wmma_sm72.h diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index d74ed3a7..4469f43e 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -47,11 +47,9 @@ cutlass_test_unit_add_executable( math.cpp mixedbits.cpp nullspace.cpp - packed_tuple.cpp pointer.cpp reverse.cpp swizzle_layout.cpp transform.cpp tuple.cpp - tuple_find.cpp ) diff --git a/test/unit/cute/core/packed_tuple.cpp b/test/unit/cute/core/packed_tuple.cpp deleted file mode 100644 index 77584e88..00000000 --- a/test/unit/cute/core/packed_tuple.cpp +++ /dev/null @@ -1,581 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#include "cutlass_unit_test.h" - -#include - -#include -#include - -#include -#include -#include -#include -#include - -namespace pt_test { - -template -struct Nonempty { - T datum; - - Nonempty(T const& t) : datum{t} {} - - friend bool operator==(Nonempty const& lhs, Nonempty const& rhs) { - return lhs.datum == rhs.datum; - } - - friend bool operator!=(Nonempty const& lhs, Nonempty const& rhs) { - return !(lhs == rhs); - } -}; - -template -struct Empty { - template - friend bool operator==(Empty const&, Empty const&) { - return V == W; - } - - template - friend bool operator!=(Empty const& lhs, Empty const& rhs) { - return !(lhs == rhs); - } -}; - -// std::tuple -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(not cute::is_standard_layout_v>); // it's not - -#if ! defined(CUTLASS_USE_PACKED_TUPLE) -// cute::tuple -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(cute::is_standard_layout_v>); // it happens to be -static_assert(not cute::is_standard_layout_v>); // it's not -#endif // CUTLASS_USE_PACKED_TUPLE - -// cute::packed_tuple -static_assert(cute::is_standard_layout_v>); -static_assert(cute::is_standard_layout_v>); -static_assert(cute::is_standard_layout_v>); -static_assert(cute::is_standard_layout_v>); // it is -static_assert(cute::is_standard_layout_v>); // it is -static_assert(cute::is_standard_layout_v, int>>); // it is -static_assert(cute::is_standard_layout_v, Empty<0>>, int>>); // it is - -////////////////////////////////////////////////////////////////////// -// packed_tuple test starts here -////////////////////////////////////////////////////////////////////// - -template < - class ExpectedPackedType, - size_t ExpectedPackedSize, - class ... Args> -constexpr void -test_packed_type_alias([[maybe_unused]] ExpectedPackedType packed, std::tuple unpacked) -{ - using cute::packed_tuple; - - if constexpr ((cute::is_standard_layout_v && ...)) { - static_assert(cute::is_standard_layout_v>); - } - - if constexpr ((cute::is_empty_v && ...)) { - static_assert(cute::is_empty_v>); - } - - static_assert(cute::tuple_size_v> == sizeof...(Args)); - - auto test_element = [unpacked] (auto index) { - static_assert(cute::is_same_v< - std::tuple_element_t>, - std::tuple_element_t> - >); - - packed_tuple sl = cute::apply(unpacked, [](auto... 
a){ return cute::make_packed_tuple(a...); }); - EXPECT_EQ(std::get(unpacked), cute::get(sl)); - }; - cute::for_each(std::make_index_sequence(), test_element); -} - -void test_packed_type_aliases() { - using cute::packed_tuple; - test_packed_type_alias, 0>({}, {}); - - test_packed_type_alias, 1, int>({7}, {7}); - test_packed_type_alias, 1, double>({1.5}, {1.5}); - - // Make sure that class types are handled the same as scalar types - test_packed_type_alias>, 1, Nonempty>( - {Nonempty{7}}, {Nonempty{7}}); - test_packed_type_alias>, 1, Nonempty>( - {Nonempty{1.5}}, {Nonempty{1.5}}); - - test_packed_type_alias, 0, Empty<0>>({}, {}); - test_packed_type_alias, 0, Empty<0>, Empty<1>>( - {}, {Empty<0>{}, Empty<1>{}}); - test_packed_type_alias, 0, Empty<0>, Empty<1>, Empty<2>>( - {}, {Empty<0>{}, Empty<1>{}, Empty<2>{}}); - - test_packed_type_alias, 1, Empty<0>, int>( - {7}, {Empty<0>{}, 7}); - test_packed_type_alias, 1, int, Empty<0>>( - {7}, {7, Empty<0>{}}); - - test_packed_type_alias, 1, int, Empty<0>, Empty<1>>( - {7}, {7, Empty<0>{}, Empty<1>{}}); - test_packed_type_alias, 1, Empty<0>, int, Empty<1>>( - {7}, {Empty<0>{}, 7, Empty<1>{}}); - test_packed_type_alias, 1, Empty<0>, Empty<1>, int>( - {7}, {Empty<0>{}, Empty<1>{}, 7}); - - test_packed_type_alias, 2, int, double, Empty<0>>( - {7, 1.5}, {7, 1.5, Empty<0>{}}); - test_packed_type_alias, 2, int, Empty<0>, double>( - {7, 1.5}, {7, Empty<0>{}, 1.5}); - test_packed_type_alias, 2, int, double, Empty<0>>( - {7, 1.5}, {7, 1.5, Empty<0>{}}); - - test_packed_type_alias, 2, int, double, Empty<0>, Empty<1>>( - {7, 1.5}, {7, 1.5, Empty<0>{}, Empty<1>{}}); - test_packed_type_alias, 2, int, Empty<0>, double, Empty<1>>( - {7, 1.5}, {7, Empty<0>{}, 1.5, Empty<1>{}}); - test_packed_type_alias, 2, int, Empty<0>, Empty<1>, double>( - {7, 1.5}, {7, Empty<0>{}, Empty<1>{}, 1.5}); - test_packed_type_alias, 2, Empty<0>, int, Empty<1>, double>( - {7, 1.5}, {Empty<0>{}, 7, Empty<1>{}, 1.5}); - test_packed_type_alias, 2, Empty<0>, Empty<1>, int, double>( - {7, 1.5}, {Empty<0>{}, Empty<1>{}, 7, 1.5}); - - test_packed_type_alias, 3, Empty<0>, int, double, float>( - {7, 1.5, 2.5f}, {Empty<0>{}, 7, 1.5, 2.5f}); - test_packed_type_alias, 3, int, Empty<0>, double, float>( - {7, 1.5, 2.5f}, {7, Empty<0>{}, 1.5, 2.5f}); - test_packed_type_alias, 3, int, double, Empty<0>, float>( - {7, 1.5, 2.5f}, {7, 1.5, Empty<0>{}, 2.5f}); - test_packed_type_alias, 3, int, double, float, Empty<0>>( - {7, 1.5, 2.5f}, {7, 1.5, 2.5f, Empty<0>{}}); -} - -template -constexpr bool test_tuple_element() { - return cute::is_same_v, ExpectedElementType>; -} - -void test_tuple_elements() { - using cute::packed_tuple; - - static_assert(test_tuple_element>, 0, Empty<0>>()); - static_assert(test_tuple_element>, 0, Empty<0>>()); -} - -// A default-constructible type. -template -struct DefaultConstructible {}; - -void test_default_constructibility() { - using cute::packed_tuple; - { - [[maybe_unused]] packed_tuple<> t_p_0; - [[maybe_unused]] packed_tuple> t_p_1; - [[maybe_unused]] packed_tuple, DefaultConstructible<1>> t_p_2; - [[maybe_unused]] packed_tuple, int, DefaultConstructible<1>> t_p_3; - } -} - -void test_sizes_and_not_storing_empty_types() { - using cute::packed_tuple; - - [[maybe_unused]] packed_tuple< - int, - pt_test::Empty<0>, - double - > pt{42, pt_test::Empty<0>{}, 1.5}; - static_assert(cute::is_standard_layout_v); - // packed_result_type must only store the packed tuple, - // and not the integer_sequence(s) used to access it. 
- // The latter can be represented entirely at compile time as types. - struct { int i; double j; } IntDouble; - static_assert(sizeof(pt) == sizeof(IntDouble)); - - EXPECT_EQ(cute::get<0>(pt), 42); - EXPECT_EQ(cute::get<1>(pt), pt_test::Empty<0>{}); - EXPECT_EQ(cute::get<2>(pt), 1.5); - packed_tuple< - pt_test::Empty<0>, - pt_test::Empty<1>, - packed_tuple< - pt_test::Empty<0>, - pt_test::Empty<1>, - packed_tuple, packed_tuple<>> - > - > pt_empty{}; - static_assert(cute::is_empty_v); - static_assert(cute::is_standard_layout_v); - static_assert(sizeof(pt_empty) == 1); - - // Template arguments must be default constructible, - // and packed_tuple itself needs a default constructor. - [[maybe_unused]] packed_tuple< - packed_tuple>, - double, - pt_test::Empty<3>> pt2; - static_assert(cute::is_standard_layout_v); - - // cute::packed_tuple, like the original cute::tuple, does not - // promise to have working CTAD (constructor template argument - // deduction). - [[maybe_unused]] packed_tuple< - packed_tuple>, - pt_test::Empty<1> - > pt3{ - packed_tuple>{42, pt_test::Empty<0>{}}, - pt_test::Empty<1>{} - }; - static_assert(cute::is_standard_layout_v); - static_assert(cute::is_same_v< - cute::tuple_element_t<0, decltype(pt3)>, - packed_tuple>>); - static_assert(cute::is_same_v< - cute::tuple_element_t<1, decltype(pt3)>, - pt_test::Empty<1>>); - static_assert(cute::tuple_size_v> == 2u); - - packed_tuple> pt3_0 = cute::get<0>(pt3); - auto pt3_0_1 = cute::get<1>(pt3_0); - static_assert(cute::is_same_v>); - - EXPECT_EQ(cute::get<0>(cute::get<0>(pt3)), 42); - EXPECT_EQ(cute::get<1>(cute::get<0>(pt3)), pt_test::Empty<0>{}); -} - -} // namespace test - -TEST(CuTe_core, PackedTuple2) -{ - CUTLASS_TRACE_HOST("-------------------------------"); - CUTLASS_TRACE_HOST("packed_tuple"); - CUTLASS_TRACE_HOST("-------------------------------"); - - pt_test::test_packed_type_aliases(); - pt_test::test_tuple_elements(); - pt_test::test_default_constructibility(); - pt_test::test_sizes_and_not_storing_empty_types(); -} - -TEST(CuTe_core, PackedTuple2Get) { - using cute::packed_tuple; - using pt_test::Empty; - using pt_test::Nonempty; - - { - using tuple_type = packed_tuple; - tuple_type pt{42}; - static_assert(cute::tuple_size_v == 1u); - static_assert(cute::is_same_v, int>); - EXPECT_EQ(cute::get<0>(pt), 42); - cute::get<0>(pt) = 43; - EXPECT_EQ(cute::get<0>(pt), 43); - } - { - using tuple_type = packed_tuple; - tuple_type const pt{42}; - EXPECT_EQ(cute::get<0>(pt), 42); - static_assert(cute::is_same_v(pt)), int const&>); - } - { - EXPECT_EQ(cute::get<0>(packed_tuple{42}), 42); - } - - { - using tuple_type = packed_tuple>; - tuple_type pt; - static_assert(cute::tuple_size_v == 1u); - static_assert(cute::is_same_v, pt_test::Empty<0>>); - EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{}); - } - { - using tuple_type = packed_tuple>; - tuple_type const pt; - EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{}); - } - { - using tuple_type = packed_tuple>; - EXPECT_EQ(cute::get<0>(tuple_type{}), pt_test::Empty<0>{}); - } - - { - using tuple_type = packed_tuple; - tuple_type pt{1, 2.5}; - static_assert(cute::tuple_size_v == 2u); - static_assert(cute::is_same_v, int>); - static_assert(cute::is_same_v, double>); - EXPECT_EQ(cute::get<0>(pt), 1); - cute::get<0>(pt) = 2; - EXPECT_EQ(cute::get<0>(pt), 2); - EXPECT_EQ(cute::get<1>(pt), 2.5); - cute::get<1>(pt) = 3.5; - EXPECT_EQ(cute::get<1>(pt), 3.5); - } - { - using tuple_type = packed_tuple; - tuple_type const pt{1, 2.5}; - EXPECT_EQ(cute::get<0>(pt), 1); - 
static_assert(cute::is_same_v(pt)), int const&>); - EXPECT_EQ(cute::get<1>(pt), 2.5); - static_assert(cute::is_same_v(pt)), double const&>); - } - { - using tuple_type = packed_tuple; - EXPECT_EQ(cute::get<0>(tuple_type{1, 2.5}), 1); - EXPECT_EQ(cute::get<1>(tuple_type{1, 2.5}), 2.5); - } - - { - using tuple_type = packed_tuple, double>; - tuple_type pt{Empty<0>{}, 2.5}; - static_assert(cute::tuple_size_v == 2u); - static_assert(cute::is_same_v, Empty<0>>); - static_assert(cute::is_same_v, double>); - EXPECT_EQ(cute::get<0>(pt), Empty<0>{}); - EXPECT_EQ(cute::get<1>(pt), 2.5); - cute::get<1>(pt) = 3.5; - EXPECT_EQ(cute::get<1>(pt), 3.5); - } - { - using tuple_type = packed_tuple, double>; - tuple_type const pt{Empty<0>{}, 2.5}; - EXPECT_EQ(cute::get<0>(pt), Empty<0>{}); - static_assert(cute::is_same_v(pt)), Empty<0>>); - EXPECT_EQ(cute::get<1>(pt), 2.5); - static_assert(cute::is_same_v(pt)), double const&>); - } - { - using tuple_type = packed_tuple, double>; - EXPECT_EQ(cute::get<0>(tuple_type{Empty<0>{}, 2.5}), Empty<0>{}); - EXPECT_EQ(cute::get<1>(tuple_type{Empty<0>{}, 2.5}), 2.5); - } - - { - using tuple_type = packed_tuple>; - tuple_type pt{1, 2.5, Nonempty{3.25f}}; - static_assert(cute::tuple_size_v == 3u); - static_assert(cute::is_same_v, int>); - static_assert(cute::is_same_v, double>); - static_assert(cute::is_same_v, Nonempty>); - EXPECT_EQ(cute::get<0>(pt), 1); - EXPECT_EQ(cute::get<1>(pt), 2.5); - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); - - cute::get<0>(pt) = 42; - EXPECT_EQ(cute::get<0>(pt), 42); - cute::get<1>(pt) = 4.5; - EXPECT_EQ(cute::get<1>(pt), 4.5); - cute::get<2>(pt) = Nonempty{3.75f}; - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f}); - } - { - using tuple_type = packed_tuple>; - tuple_type const pt{1, 2.5, Nonempty{3.25f}}; - EXPECT_EQ(cute::get<0>(pt), 1); - EXPECT_EQ(cute::get<1>(pt), 2.5); - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); - } - { - using tuple_type = packed_tuple>; - EXPECT_EQ((cute::get<0>(tuple_type{1, 2.5, Nonempty{3.25f}})), 1); - EXPECT_EQ((cute::get<1>(tuple_type{1, 2.5, Nonempty{3.25f}})), 2.5); - EXPECT_EQ((cute::get<2>(tuple_type{1, 2.5, Nonempty{3.25f}})), Nonempty{3.25f}); - } - - { - using tuple_type = packed_tuple, Nonempty>; - packed_tuple, Nonempty> pt{1, Empty<0>{}, Nonempty{3.25f}}; - static_assert(cute::tuple_size_v == 3u); - static_assert(cute::is_same_v, int>); - static_assert(cute::is_same_v, Empty<0>>); - static_assert(cute::is_same_v, Nonempty>); - EXPECT_EQ(cute::get<0>(pt), 1); - EXPECT_EQ(cute::get<1>(pt), Empty<0>{}); - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); - - cute::get<0>(pt) = 42; - EXPECT_EQ(cute::get<0>(pt), 42); - cute::get<2>(pt) = Nonempty{3.75f}; - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f}); - } - { - using tuple_type = packed_tuple, Nonempty>; - tuple_type const pt{1, Empty<0>{}, Nonempty{3.25f}}; - EXPECT_EQ(cute::get<0>(pt), 1); - EXPECT_EQ(cute::get<1>(pt), Empty<0>{}); - EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f}); - } - { - using tuple_type = packed_tuple, Nonempty>; - EXPECT_EQ((cute::get<0>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), 1); - EXPECT_EQ((cute::get<1>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Empty<0>{}); - EXPECT_EQ((cute::get<2>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Nonempty{3.25f}); - } -} - -namespace pt_test { - -// An empty class type to which Empty is convertible. 
-template -struct ConvertibleFromEmpty { - constexpr ConvertibleFromEmpty() = default; - constexpr ConvertibleFromEmpty(Empty) {} - - template - friend constexpr bool operator==(ConvertibleFromEmpty const&, ConvertibleFromEmpty const&) { - return Value == OtherValue; - } - - template - friend constexpr bool operator!=(ConvertibleFromEmpty const& lhs, ConvertibleFromEmpty const& rhs) { - return !(lhs == rhs); - } -}; - -} // end namespace pt_test - -TEST(CuTe_core, PackedTupleConstexprDefaultConstruction) { - // Make sure that packed_tuple's default constructor is constexpr. - // MSVC makes this a bit more challenging than usual. - - using pt_test::Empty; - { - [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; - [[maybe_unused]] constexpr cute::detail::ESO_t eso2{}; - } - { - [[maybe_unused]] constexpr cute::detail::ESO_t, Empty<1>> eso0{}; - [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; - [[maybe_unused]] constexpr cute::detail::ESO_t, int64_t> eso2{}; - [[maybe_unused]] constexpr cute::detail::ESO_t eso3{}; - } -} - -TEST(CuTe_core, PackedTupleConvertingConstruction) { - using cute::packed_tuple; - using pt_test::ConvertibleFromEmpty; - using pt_test::Empty; - using pt_test::Nonempty; - - { - using tuple_type = cute::tuple>; - [[maybe_unused]] tuple_type t(7); - EXPECT_EQ(cute::get<0>(t), Nonempty(7)); - } - { - using tuple_type = packed_tuple>; - [[maybe_unused]] tuple_type t(7); - EXPECT_EQ(cute::get<0>(t), Nonempty(7)); - } - { - using tuple_type = cute::tuple>; - [[maybe_unused]] tuple_type t(Empty<0>{}); - EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); - } - { - using tuple_type = packed_tuple>; - [[maybe_unused]] tuple_type t(Empty<0>{}); - EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); - } - - { - using tuple_type = cute::tuple>; - [[maybe_unused]] tuple_type t(1.5f, 7); - EXPECT_EQ(cute::get<0>(t), 1.5f); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - { - using tuple_type = packed_tuple>; - [[maybe_unused]] tuple_type t(1.5f, 7); - EXPECT_EQ(cute::get<0>(t), 1.5f); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - - { - using tuple_type = cute::tuple, Nonempty>; - [[maybe_unused]] tuple_type t(Empty<0>{}, 7); - EXPECT_EQ(cute::get<0>(t), Empty<0>{}); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - { - using tuple_type = packed_tuple, Nonempty>; - [[maybe_unused]] tuple_type t(Empty<0>{}, 7); - EXPECT_EQ(cute::get<0>(t), Empty<0>{}); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - - { - using tuple_type = cute::tuple, Nonempty>; - [[maybe_unused]] tuple_type t(Empty<0>{}, 7); - EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - { - using tuple_type = packed_tuple, Nonempty>; - [[maybe_unused]] tuple_type t(Empty<0>{}, 7); - EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); - EXPECT_EQ(cute::get<1>(t), Nonempty(7)); - } - - { - using inner_tuple_type = cute::tuple>; - using outer_tuple_type = cute::tuple; - [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); - } - { - using inner_tuple_type = packed_tuple>; - using outer_tuple_type = packed_tuple; - [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); - } - { - using inner_tuple_type = cute::tuple>; - using outer_tuple_type = cute::tuple; - [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); - } - { - using inner_tuple_type = packed_tuple>; - using outer_tuple_type = packed_tuple; - [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); - } - -} - - diff --git 
a/test/unit/cute/core/tuple.cpp b/test/unit/cute/core/tuple.cpp index f1efb36e..ea31edd9 100644 --- a/test/unit/cute/core/tuple.cpp +++ b/test/unit/cute/core/tuple.cpp @@ -32,6 +32,13 @@ #include "cutlass_unit_test.h" #include + +#include +#include + +#include +#include +#include #include TEST(CuTe_core, Tuple) @@ -120,6 +127,11 @@ TEST(CuTe_core, Tuple) ASSERT_TRUE(sizeof(tuple_3h_m_type) == 12); ASSERT_TRUE(!std::is_empty::value); + ASSERT_TRUE(sizeof(cute::tuple<_1, _1, cute::tuple>) == 4); + ASSERT_TRUE(sizeof(cute::tuple<_1, _0, cute::tuple>) == 4); + ASSERT_TRUE(sizeof(cute::tuple<_1, cute::tuple<_1, int32_t>>) == 4); + ASSERT_TRUE(sizeof(cute::tuple<_1, cute::tuple<_0, int32_t>>) == 4); + CUTLASS_TRACE_HOST("-------------------------------"); CUTLASS_TRACE_HOST("SIMPLE TUPLE OPS"); CUTLASS_TRACE_HOST("-------------------------------"); @@ -264,3 +276,588 @@ TEST(CuTe_core, Tuple) CUTLASS_TRACE_HOST("a(_,1,_,(1,2)) = " << dice(make_coord(_,1,_,make_coord(1,2)), a)); } } + +namespace pt_test { + +template +struct Nonempty { + T datum; + + Nonempty(T const& t) : datum{t} {} + + friend bool operator==(Nonempty const& lhs, Nonempty const& rhs) { + return lhs.datum == rhs.datum; + } + + friend bool operator!=(Nonempty const& lhs, Nonempty const& rhs) { + return !(lhs == rhs); + } +}; + +template +struct Empty { + template + friend bool operator==(Empty const&, Empty const&) { + return V == W; + } + + template + friend bool operator!=(Empty const& lhs, Empty const& rhs) { + return !(lhs == rhs); + } +}; + +// std::tuple +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(cute::is_standard_layout_v>); // it happens to be +static_assert(not cute::is_standard_layout_v>); // it's not + +// cute::tuple +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); +static_assert(cute::is_standard_layout_v>); // it is +static_assert(cute::is_standard_layout_v>); // it is +static_assert(cute::is_standard_layout_v, int>>); // it is +static_assert(cute::is_standard_layout_v, Empty<0>>, int>>); // it is + +////////////////////////////////////////////////////////////////////// +// tuple test starts here +////////////////////////////////////////////////////////////////////// + +template < + class ExpectedPackedType, + size_t ExpectedPackedSize, + class ... Args> +constexpr void +test_packed_type_alias([[maybe_unused]] ExpectedPackedType packed, std::tuple unpacked) +{ + using cute::tuple; + + if constexpr ((cute::is_standard_layout_v && ...)) { + static_assert(cute::is_standard_layout_v>); + } + + if constexpr ((cute::is_empty_v && ...)) { + static_assert(cute::is_empty_v>); + } + + static_assert(cute::tuple_size_v> == sizeof...(Args)); + + auto test_element = [unpacked] (auto index) { + static_assert(cute::is_same_v< + std::tuple_element_t>, + std::tuple_element_t> + >); + + tuple sl = cute::apply(unpacked, [](auto... 
a){ return cute::make_tuple(a...); }); + EXPECT_EQ(std::get(unpacked), cute::get(sl)); + }; + cute::for_each(std::make_index_sequence(), test_element); +} + +void test_packed_type_aliases() { + using cute::tuple; + test_packed_type_alias, 0>({}, {}); + + test_packed_type_alias, 1, int>({7}, {7}); + test_packed_type_alias, 1, double>({1.5}, {1.5}); + + // Make sure that class types are handled the same as scalar types + test_packed_type_alias>, 1, Nonempty>( + {Nonempty{7}}, {Nonempty{7}}); + test_packed_type_alias>, 1, Nonempty>( + {Nonempty{1.5}}, {Nonempty{1.5}}); + + test_packed_type_alias, 0, Empty<0>>({}, {}); + test_packed_type_alias, 0, Empty<0>, Empty<1>>( + {}, {Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 0, Empty<0>, Empty<1>, Empty<2>>( + {}, {Empty<0>{}, Empty<1>{}, Empty<2>{}}); + + test_packed_type_alias, 1, Empty<0>, int>( + {7}, {Empty<0>{}, 7}); + test_packed_type_alias, 1, int, Empty<0>>( + {7}, {7, Empty<0>{}}); + + test_packed_type_alias, 1, int, Empty<0>, Empty<1>>( + {7}, {7, Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 1, Empty<0>, int, Empty<1>>( + {7}, {Empty<0>{}, 7, Empty<1>{}}); + test_packed_type_alias, 1, Empty<0>, Empty<1>, int>( + {7}, {Empty<0>{}, Empty<1>{}, 7}); + + test_packed_type_alias, 2, int, double, Empty<0>>( + {7, 1.5}, {7, 1.5, Empty<0>{}}); + test_packed_type_alias, 2, int, Empty<0>, double>( + {7, 1.5}, {7, Empty<0>{}, 1.5}); + test_packed_type_alias, 2, int, double, Empty<0>>( + {7, 1.5}, {7, 1.5, Empty<0>{}}); + + test_packed_type_alias, 2, int, double, Empty<0>, Empty<1>>( + {7, 1.5}, {7, 1.5, Empty<0>{}, Empty<1>{}}); + test_packed_type_alias, 2, int, Empty<0>, double, Empty<1>>( + {7, 1.5}, {7, Empty<0>{}, 1.5, Empty<1>{}}); + test_packed_type_alias, 2, int, Empty<0>, Empty<1>, double>( + {7, 1.5}, {7, Empty<0>{}, Empty<1>{}, 1.5}); + test_packed_type_alias, 2, Empty<0>, int, Empty<1>, double>( + {7, 1.5}, {Empty<0>{}, 7, Empty<1>{}, 1.5}); + test_packed_type_alias, 2, Empty<0>, Empty<1>, int, double>( + {7, 1.5}, {Empty<0>{}, Empty<1>{}, 7, 1.5}); + + test_packed_type_alias, 3, Empty<0>, int, double, float>( + {7, 1.5, 2.5f}, {Empty<0>{}, 7, 1.5, 2.5f}); + test_packed_type_alias, 3, int, Empty<0>, double, float>( + {7, 1.5, 2.5f}, {7, Empty<0>{}, 1.5, 2.5f}); + test_packed_type_alias, 3, int, double, Empty<0>, float>( + {7, 1.5, 2.5f}, {7, 1.5, Empty<0>{}, 2.5f}); + test_packed_type_alias, 3, int, double, float, Empty<0>>( + {7, 1.5, 2.5f}, {7, 1.5, 2.5f, Empty<0>{}}); +} + +template +constexpr bool test_tuple_element() { + return cute::is_same_v, ExpectedElementType>; +} + +void test_tuple_elements() { + using cute::tuple; + + static_assert(test_tuple_element>, 0, Empty<0>>()); + static_assert(test_tuple_element>, 0, Empty<0>>()); +} + +// A default-constructible type. +template +struct DefaultConstructible {}; + +void test_default_constructibility() { + using cute::tuple; + { + [[maybe_unused]] tuple<> t_p_0; + [[maybe_unused]] tuple> t_p_1; + [[maybe_unused]] tuple, DefaultConstructible<1>> t_p_2; + [[maybe_unused]] tuple, int, DefaultConstructible<1>> t_p_3; + } +} + +void test_sizes_and_not_storing_empty_types() { + using cute::tuple; + + [[maybe_unused]] tuple< + int, + pt_test::Empty<0>, + double + > pt{42, pt_test::Empty<0>{}, 1.5}; + static_assert(cute::is_standard_layout_v); + // packed_result_type must only store the packed tuple, + // and not the integer_sequence(s) used to access it. + // The latter can be represented entirely at compile time as types. 
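+  // A minimal sketch for intuition, assuming hypothetical types B and D that are
+  // not defined in this test: the size comparison below relies on empty members
+  // occupying no storage, in the same spirit as the empty-base optimization,
+  //   struct B {};                        // empty
+  //   struct D : B { int i; double j; };  // expected: sizeof(D) == sizeof(IntDouble) below
+  // so a tuple holding an int, an empty type, and a double is expected to pack
+  // down to just the int plus the double.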
+  struct { int i; double j; } IntDouble;
+  static_assert(sizeof(pt) == sizeof(IntDouble));
+
+  EXPECT_EQ(cute::get<0>(pt), 42);
+  EXPECT_EQ(cute::get<1>(pt), pt_test::Empty<0>{});
+  EXPECT_EQ(cute::get<2>(pt), 1.5);
+  tuple<
+    pt_test::Empty<0>,
+    pt_test::Empty<1>,
+    tuple<
+      pt_test::Empty<0>,
+      pt_test::Empty<1>,
+      tuple<pt_test::Empty<0>, tuple<>>
+    >
+  > pt_empty{};
+  static_assert(cute::is_empty_v<decltype(pt_empty)>);
+  static_assert(cute::is_standard_layout_v<decltype(pt_empty)>);
+  static_assert(sizeof(pt_empty) == 1);
+
+  // Template arguments must be default constructible,
+  // and tuple itself needs a default constructor.
+  [[maybe_unused]] tuple<
+    tuple<int, pt_test::Empty<2>>,
+    double,
+    pt_test::Empty<3>> pt2;
+  static_assert(cute::is_standard_layout_v<decltype(pt2)>);
+
+  // cute::tuple, like the original cute::tuple, does not
+  // promise to have working CTAD (constructor template argument
+  // deduction).
+  [[maybe_unused]] tuple<
+    tuple<int, pt_test::Empty<0>>,
+    pt_test::Empty<1>
+  > pt3{
+    tuple<int, pt_test::Empty<0>>{42, pt_test::Empty<0>{}},
+    pt_test::Empty<1>{}
+  };
+  static_assert(cute::is_standard_layout_v<decltype(pt3)>);
+  static_assert(cute::is_same_v<
+    cute::tuple_element_t<0, decltype(pt3)>,
+    tuple<int, pt_test::Empty<0>>>);
+  static_assert(cute::is_same_v<
+    cute::tuple_element_t<1, decltype(pt3)>,
+    pt_test::Empty<1>>);
+  static_assert(cute::tuple_size_v<cute::tuple_element_t<0, decltype(pt3)>> == 2u);
+
+  tuple<int, pt_test::Empty<0>> pt3_0 = cute::get<0>(pt3);
+  auto pt3_0_1 = cute::get<1>(pt3_0);
+  static_assert(cute::is_same_v<decltype(pt3_0_1), pt_test::Empty<0>>);
+
+  EXPECT_EQ(cute::get<0>(cute::get<0>(pt3)), 42);
+  EXPECT_EQ(cute::get<1>(cute::get<0>(pt3)), pt_test::Empty<0>{});
+}
+
+} // namespace test
+
+TEST(CuTe_core, PackedTuple2)
+{
+  CUTLASS_TRACE_HOST("-------------------------------");
+  CUTLASS_TRACE_HOST("tuple");
+  CUTLASS_TRACE_HOST("-------------------------------");
+
+  pt_test::test_packed_type_aliases();
+  pt_test::test_tuple_elements();
+  pt_test::test_default_constructibility();
+  pt_test::test_sizes_and_not_storing_empty_types();
+}
+
+TEST(CuTe_core, PackedTuple2Get) {
+  using cute::tuple;
+  using pt_test::Empty;
+  using pt_test::Nonempty;
+
+  {
+    using tuple_type = tuple<int>;
+    tuple_type pt{42};
+    static_assert(cute::tuple_size_v<tuple_type> == 1u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, int>);
+    EXPECT_EQ(cute::get<0>(pt), 42);
+    cute::get<0>(pt) = 43;
+    EXPECT_EQ(cute::get<0>(pt), 43);
+  }
+  {
+    using tuple_type = tuple<int>;
+    tuple_type const pt{42};
+    EXPECT_EQ(cute::get<0>(pt), 42);
+    static_assert(cute::is_same_v<decltype(cute::get<0>(pt)), int const&>);
+  }
+  {
+    EXPECT_EQ(cute::get<0>(tuple<int>{42}), 42);
+  }
+
+  {
+    using tuple_type = tuple<pt_test::Empty<0>>;
+    tuple_type pt;
+    static_assert(cute::tuple_size_v<tuple_type> == 1u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, pt_test::Empty<0>>);
+    EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{});
+  }
+  {
+    using tuple_type = tuple<pt_test::Empty<0>>;
+    tuple_type const pt;
+    EXPECT_EQ(cute::get<0>(pt), pt_test::Empty<0>{});
+  }
+  {
+    using tuple_type = tuple<pt_test::Empty<0>>;
+    EXPECT_EQ(cute::get<0>(tuple_type{}), pt_test::Empty<0>{});
+  }
+
+  {
+    using tuple_type = tuple<int, double>;
+    tuple_type pt{1, 2.5};
+    static_assert(cute::tuple_size_v<tuple_type> == 2u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, int>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<1, tuple_type>, double>);
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    cute::get<0>(pt) = 2;
+    EXPECT_EQ(cute::get<0>(pt), 2);
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    cute::get<1>(pt) = 3.5;
+    EXPECT_EQ(cute::get<1>(pt), 3.5);
+  }
+  {
+    using tuple_type = tuple<int, double>;
+    tuple_type const pt{1, 2.5};
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    static_assert(cute::is_same_v<decltype(cute::get<0>(pt)), int const&>);
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    static_assert(cute::is_same_v<decltype(cute::get<1>(pt)), double const&>);
+  }
+  {
+    using tuple_type = tuple<int, double>;
+    EXPECT_EQ(cute::get<0>(tuple_type{1, 2.5}), 1);
+    EXPECT_EQ(cute::get<1>(tuple_type{1, 2.5}), 2.5);
+  }
+
+  {
+    using tuple_type = tuple<Empty<0>, double>;
+    tuple_type pt{Empty<0>{}, 2.5};
+    static_assert(cute::tuple_size_v<tuple_type> == 2u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, Empty<0>>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<1, tuple_type>, double>);
+    EXPECT_EQ(cute::get<0>(pt), Empty<0>{});
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    cute::get<1>(pt) = 3.5;
+    EXPECT_EQ(cute::get<1>(pt), 3.5);
+  }
+  {
+    using tuple_type = tuple<Empty<0>, double>;
+    tuple_type const pt{Empty<0>{}, 2.5};
+    EXPECT_EQ(cute::get<0>(pt), Empty<0>{});
+    static_assert(cute::is_same_v<decltype(cute::get<0>(pt)), Empty<0>>);
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    static_assert(cute::is_same_v<decltype(cute::get<1>(pt)), double const&>);
+  }
+  {
+    using tuple_type = tuple<Empty<0>, double>;
+    EXPECT_EQ(cute::get<0>(tuple_type{Empty<0>{}, 2.5}), Empty<0>{});
+    EXPECT_EQ(cute::get<1>(tuple_type{Empty<0>{}, 2.5}), 2.5);
+  }
+
+  {
+    using tuple_type = tuple<int, double, Nonempty<float>>;
+    tuple_type pt{1, 2.5, Nonempty{3.25f}};
+    static_assert(cute::tuple_size_v<tuple_type> == 3u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, int>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<1, tuple_type>, double>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<2, tuple_type>, Nonempty<float>>);
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f});
+
+    cute::get<0>(pt) = 42;
+    EXPECT_EQ(cute::get<0>(pt), 42);
+    cute::get<1>(pt) = 4.5;
+    EXPECT_EQ(cute::get<1>(pt), 4.5);
+    cute::get<2>(pt) = Nonempty{3.75f};
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f});
+  }
+  {
+    using tuple_type = tuple<int, double, Nonempty<float>>;
+    tuple_type const pt{1, 2.5, Nonempty{3.25f}};
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    EXPECT_EQ(cute::get<1>(pt), 2.5);
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f});
+  }
+  {
+    using tuple_type = tuple<int, double, Nonempty<float>>;
+    EXPECT_EQ((cute::get<0>(tuple_type{1, 2.5, Nonempty{3.25f}})), 1);
+    EXPECT_EQ((cute::get<1>(tuple_type{1, 2.5, Nonempty{3.25f}})), 2.5);
+    EXPECT_EQ((cute::get<2>(tuple_type{1, 2.5, Nonempty{3.25f}})), Nonempty{3.25f});
+  }
+
+  {
+    using tuple_type = tuple<int, Empty<0>, Nonempty<float>>;
+    tuple<int, Empty<0>, Nonempty<float>> pt{1, Empty<0>{}, Nonempty{3.25f}};
+    static_assert(cute::tuple_size_v<tuple_type> == 3u);
+    static_assert(cute::is_same_v<cute::tuple_element_t<0, tuple_type>, int>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<1, tuple_type>, Empty<0>>);
+    static_assert(cute::is_same_v<cute::tuple_element_t<2, tuple_type>, Nonempty<float>>);
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    EXPECT_EQ(cute::get<1>(pt), Empty<0>{});
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f});
+
+    cute::get<0>(pt) = 42;
+    EXPECT_EQ(cute::get<0>(pt), 42);
+    cute::get<2>(pt) = Nonempty{3.75f};
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.75f});
+  }
+  {
+    using tuple_type = tuple<int, Empty<0>, Nonempty<float>>;
+    tuple_type const pt{1, Empty<0>{}, Nonempty{3.25f}};
+    EXPECT_EQ(cute::get<0>(pt), 1);
+    EXPECT_EQ(cute::get<1>(pt), Empty<0>{});
+    EXPECT_EQ(cute::get<2>(pt), Nonempty{3.25f});
+  }
+  {
+    using tuple_type = tuple<int, Empty<0>, Nonempty<float>>;
+    EXPECT_EQ((cute::get<0>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), 1);
+    EXPECT_EQ((cute::get<1>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Empty<0>{});
+    EXPECT_EQ((cute::get<2>(tuple_type{1, Empty<0>{}, Nonempty{3.25f}})), Nonempty{3.25f});
+  }
+}
+
+namespace pt_test {
+
+// An empty class type to which Empty is convertible.
+template +struct ConvertibleFromEmpty { + constexpr ConvertibleFromEmpty() = default; + constexpr ConvertibleFromEmpty(Empty) {} + + template + friend constexpr bool operator==(ConvertibleFromEmpty const&, ConvertibleFromEmpty const&) { + return Value == OtherValue; + } + + template + friend constexpr bool operator!=(ConvertibleFromEmpty const& lhs, ConvertibleFromEmpty const& rhs) { + return !(lhs == rhs); + } +}; + +} // end namespace pt_test + +TEST(CuTe_core, PackedTupleConstexprDefaultConstruction) { + // Make sure that tuple's default constructor is constexpr. + // MSVC makes this a bit more challenging than usual. + + using pt_test::Empty; + { + [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; + [[maybe_unused]] constexpr cute::detail::ESO_t eso2{}; + } + { + [[maybe_unused]] constexpr cute::detail::ESO_t, Empty<1>> eso0{}; + [[maybe_unused]] constexpr cute::detail::ESO_t> eso1{}; + [[maybe_unused]] constexpr cute::detail::ESO_t, int64_t> eso2{}; + [[maybe_unused]] constexpr cute::detail::ESO_t eso3{}; + } +} + +TEST(CuTe_core, PackedTupleConvertingConstruction) { + using cute::tuple; + using pt_test::ConvertibleFromEmpty; + using pt_test::Empty; + using pt_test::Nonempty; + + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(7); + EXPECT_EQ(cute::get<0>(t), Nonempty(7)); + } + { + using tuple_type = tuple>; + [[maybe_unused]] tuple_type t(7); + EXPECT_EQ(cute::get<0>(t), Nonempty(7)); + } + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(Empty<0>{}); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + } + { + using tuple_type = tuple>; + [[maybe_unused]] tuple_type t(Empty<0>{}); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + } + + { + using tuple_type = cute::tuple>; + [[maybe_unused]] tuple_type t(1.5f, 7); + EXPECT_EQ(cute::get<0>(t), 1.5f); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = tuple>; + [[maybe_unused]] tuple_type t(1.5f, 7); + EXPECT_EQ(cute::get<0>(t), 1.5f); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using tuple_type = cute::tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), Empty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), Empty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using tuple_type = cute::tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + { + using tuple_type = tuple, Nonempty>; + [[maybe_unused]] tuple_type t(Empty<0>{}, 7); + EXPECT_EQ(cute::get<0>(t), ConvertibleFromEmpty<0>{}); + EXPECT_EQ(cute::get<1>(t), Nonempty(7)); + } + + { + using inner_tuple_type = cute::tuple>; + using outer_tuple_type = cute::tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = tuple>; + using outer_tuple_type = tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = cute::tuple>; + using outer_tuple_type = cute::tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } + { + using inner_tuple_type = tuple>; + using outer_tuple_type = tuple; + [[maybe_unused]] outer_tuple_type t(inner_tuple_type{Empty<0>{}}); + } +} + +namespace test { + +template +void test_tuple_find(Tuple const& t) { + auto index = cute::find(t); + 
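+  // A hedged reading of this helper (its exact template arguments are an
+  // assumption): cute::find is presumably called with the searched-for type as
+  // an explicit template argument and yields an integral-constant index, so the
+  // assertion below verifies at compile time that the index equals ExpectedIndex.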
+  static_assert(decltype(index)::value == ExpectedIndex);
+}
+
+template