diff --git a/CHANGELOG.md b/CHANGELOG.md
index 63e3e80e..fc269c8b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,7 +32,9 @@
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
  - Support for preferred and fallback cluster shapes via profiler command line argument parsing to set dynamic cluster shapes.
  - Support for dynamic datatypes via profiler command line argument parsing to set the dynamic datatype in TCGen05 MMA instruction descriptors.
+ - Support for mixed input GEMM kernels on Hopper in the profiler.
* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
+* A new 3.x version of grouped GEMM in the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
@@ -45,12 +47,16 @@
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
  - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128.
+ - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
+* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
+ - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
+ - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
  - A new [functionality documentation](./media/docs/functionality.md) specifically for the 3.x API, comprehensively documenting all supported kernel types, data types, kernel features, minimum CUDA toolkit support, etc. for 3.x supported architectures.
- Updates to [compatibility](./README.md#compatibility) section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures, and [Target Architecture](./README.md#Target-Architecture).
- - Support grouped GEMM in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
+ - Updates to [profiler documentation](./media/docs/profiler.md) for testing mixed input GEMM kernels on Hopper.
## [3.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.7.0) (2025-01-11)
- [Hopper blockwise scaling FP8 GEMM](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu) uses 2D scaling tensor, assigning one value per threadblock. This allows a finer-grained scaling to be applied for each output tile per gemm-k iteration. The operands and scaling tensors are loaded from global memory to shared memory using TMA and cp_async, respectively. The scaling is applied inside the mainloop. Details with figures are [here](https://github.com/NVIDIA/cutlass/pull/1932#issue-2645398439).
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b9de4f96..0b68b435 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -150,14 +150,14 @@ set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUT
set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests")
set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests")
set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest")
-set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.")
-if (CUTLASS_USE_PACKED_TUPLE)
- list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1")
- message(STATUS "Make cute::tuple be the new standard-layout tuple type")
-elseif()
- message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout")
+
+if (CUTLASS_ENABLE_TESTS AND CUTLASS_ENABLE_PROFILER)
+ set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT ON)
+else()
+ set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT OFF)
endif()
+set(CUTLASS_ENABLE_PROFILER_UNIT_TESTS ${CUTLASS_ENABLE_PROFILER_UNIT_TESTS_INIT} CACHE BOOL "Enable CUTLASS Profiler-based Unit Tests")
+set(CUTLASS_ENABLE_SELF_CONTAINED_INCLUDES_CHECK ON CACHE BOOL "Enable CUTLASS check for self-contained header includes")
################################################################################
@@ -406,7 +406,7 @@ endif()
# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CUTLASS_CLANG_HOST_COMPILE)
-
+
set(FLAGS_TO_ADD
"-Wno-error=implicit-int-conversion"
"-Wno-error=pass-failed"
@@ -414,13 +414,13 @@ if (CUTLASS_CLANG_HOST_COMPILE)
"-Wno-sign-conversion"
"-Wno-unused-parameter"
)
-
+
foreach(FLAG ${FLAGS_TO_ADD})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}")
list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}")
endforeach()
-
+
endif()
if (NOT MSVC AND CUTLASS_NVCC_KEEP)
@@ -486,7 +486,7 @@ if (CUTLASS_CLANG_DEVICE_COMPILE)
link_libraries(nvidia::cudart)
link_libraries(nvidia::cuda_driver)
-
+
endif()
#Report CUDA build flags
@@ -561,7 +561,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real)
endif()
if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE)
- # If we're using clang for device compilation, the ptx is inserted
+ # If we're using clang for device compilation, the ptx is inserted
# via another command line option and the `-virtual` flags will cause an error.
list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual)
endif()
@@ -922,7 +922,7 @@ function(cutlass_add_executable_tests NAME TARGET)
if (NOT __DO_NOT_LOWERCASE_TEST_NAME)
string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME)
endif()
-
+
# The following rigmarole is needed to deal with spaces and possible quotes in
# command line arguments. The options are passed "by reference" as the actual
# variable names holding the real options. We then expand these in a way that
@@ -1007,46 +1007,51 @@ function(cutlass_generate_profiler_tests NAME)
endif()
file(STRINGS ${CUTLASS_PROFILER_REGRESSION_LIST_FILE} TEST_LIST)
-
foreach(TEST IN LISTS TEST_LIST)
-
+ set(TEMP_TEST ${TEST})
if ("${TEST}" MATCHES " *cutlass_profiler.*")
- # Generate a flattened name for the test from the test command line.
- string(REPLACE "," ";" TEST_NAME_LIST ${TEST})
- list(GET TEST_NAME_LIST 0 TEST)
- string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST}")
- list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass")
- list(JOIN TEST_NAME "_" TEST_NAME)
- string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}")
- string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}")
- string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}")
- string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "=" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}")
- string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}")
+ # Generate a flattened name for the test from the test command line.
+ string(REPLACE "," ";" TEST_NAME_LIST ${TEMP_TEST})
+ string(REGEX REPLACE "\\*" "_" TEST_NAME "${TEMP_TEST}")
+ string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "\\\"\\{\\\"\\\"input_params.*\\{.*\\}\\}\\\"" "" TEST "${TEST}")
+ string(REGEX REPLACE "," ";" TEST "${TEST}")
+ string(REGEX MATCHALL "[a-zA-Z0-9_=]+" TEST_NAME "${TEST_NAME}")
+ list(FILTER TEST_NAME EXCLUDE REGEX "cutlass_profiler|mode=trace|providers=cutlass")
+ list(JOIN TEST_NAME "_" TEST_NAME)
+ string(REGEX REPLACE "_verification_required=(true|false)" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "_verification_providers=device" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "batch_count=" "batch" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "cluster_m=" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "_cluster_n=" "x" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "_cluster_k=[0-9]+" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "cluster_m_fallback=" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "_cluster_n_fallback=" "x" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "_cluster_k_fallback=[0-9]+" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "runtime_input_datatype_a=" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "runtime_input_datatype_b=" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "verification_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "warmup_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "profiling_iterations=[0-9]+" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "sleep_duration=[0-9]+" "" TEST_NAME "${TEST_NAME}")
+ string(REGEX REPLACE "profiling_enabled=(true|false)" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "=" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "_error_on_no_match" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "_error_if_nothing_is_profiled" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "kernels" "" TEST_NAME "${TEST_NAME}")
+ string(REPLACE "operation" "" TEST_NAME "${TEST_NAME}")
- if (__DO_NOT_LOWERCASE_TEST_NAME)
- string(TEST_NAME_LOWER "${TEST_NAME}")
- else()
- string(TOLOWER "${TEST_NAME}" TEST_NAME_LOWER)
- endif()
+ if (NOT __DO_NOT_LOWERCASE_TEST_NAME)
+ string(TOLOWER "${TEST_NAME}" TEST_NAME)
+ endif()
- # Munge the test command
- string(REPLACE "cutlass_profiler" "" TEST "${TEST}")
- set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME_LOWER}")
- set(TEST_COMMAND_${TEST_NAME_LOWER} "${TEST}")
- list(APPEND TEST_COMMAND_VARS ${TEST_NAME_LOWER})
+ # Munge the test command
+ string(REPLACE "cutlass_profiler" "" TEST "${TEST}")
+ set(TEST "${TEST}" ${__CUTLASS_PROFILER_EXTRA_OPTIONS} "--junit-output=${TEST_NAME}")
+ set(TEST_COMMAND_${TEST_NAME} "${TEST}")
+ list(APPEND TEST_COMMAND_VARS ${TEST_NAME})
endif()
endforeach()
@@ -1084,6 +1089,14 @@ if (CUTLASS_ENABLE_TESTS)
if (CUTLASS_ENABLE_GTEST_UNIT_TESTS)
add_dependencies(test_all test_unit)
endif()
+ if (CUTLASS_ENABLE_PROFILER_UNIT_TESTS AND CUTLASS_BUILD_FOR_PROFILER_REGRESSIONS)
+ # Generate profiler based unit test
+ cutlass_generate_profiler_tests(
+ tup
+ DEPENDEES test_unit
+ )
+ endif()
+
endif()
if (CUTLASS_INSTALL_TESTS)
diff --git a/ACTIVE_DEVELOPERS.md b/CONTRIBUTORS.md
similarity index 96%
rename from ACTIVE_DEVELOPERS.md
rename to CONTRIBUTORS.md
index 6ae47b43..1ef06a36 100644
--- a/ACTIVE_DEVELOPERS.md
+++ b/CONTRIBUTORS.md
@@ -27,7 +27,7 @@ Siyu Liu
Richard Cai
Vikas Gupta
Ethan Yan
-Vijay Thakkar (CUTLASS 3.x founding member)
+Vijay Thakkar (CUTLASS 3.x and CuTe founding member)
Cris Cecka (CuTe and CUTLASS 3.x founding member)
Lawrence Ryan
Qun Song
diff --git a/README.md b/README.md
index f9f23c08..ada18b39 100644
--- a/README.md
+++ b/README.md
@@ -43,23 +43,23 @@ architecture.
CUTLASS 3.8 is the first release that supports the NVIDIA Blackwell SM100 architecture.
For a background on Blackwell's new features, please consult the PTX documentation for CUDA 12.8.
-* Support for new CuTe building blocks specifically for Blackwell architecture:
+* Support for new CuTe building blocks specifically for Blackwell SM100 architecture:
- [5th generation Blackwell Tensor Core instructions (TCGen05)](./include/cute/atom/mma_traits_sm100.hpp) via CuTe MMA atoms.
- Extensions to [Tensor Memory Accelerator](./include/cute/atom/copy_traits_sm100_tma.hpp) via CuTe Copy atoms.
- - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp#L290) across CuTe as a first class data locale.
+ - Exposure of Blackwell's new tensor memory (note: distinct from TMA) as [`tmem`](./include/cute/pointer.hpp) across CuTe as a first class data locale.
- Exposure of [`tmem->rmem`, `rmem->tmem` and `smem->tmem data movement instructions`](./include/cute/atom/copy_traits_sm100.hpp) as copy atoms in CuTe.
- [`make_tmem_copy()`](./include/cute/atom/copy_traits_sm100.hpp) utility method to ease creation of tiled copies for tmem copy atoms.
- Support for [new variants of LDSM on Blackwell](./include/cute/atom/copy_traits_sm100.hpp) via CuTe Copy atoms.
-* Support for new CUTLASS building blocks specifically for Blackwell architecture:
+* Support for new CUTLASS building blocks specifically for Blackwell SM100 architecture:
- Various narrow precision [FP4, FP6, and FP8](./include/cutlass/exmy_base.h) formats as well as their [block-scaled variants NVFP4, MXFP4, MXFP6, and MXFP8](./include/cutlass/float_subbyte.h)
- [Pipelines that implement Blackwell specific synchronization](./include/cutlass/pipeline/sm100_pipeline.hpp).
- [Cluster launch control API supporting preferred and fallback cluster shapes](./include/cutlass/cluster_launch.hpp).
- Data types including NVFP4, MXFP4, MXFP6, and MXFP8 and all their supported element and scale factor types.
- Tile schedulers using [Blackwell's Cluster Launch Control (CLC) feature](./media/docs/blackwell_cluster_launch_control.md) to implement dynamic persistence scheduling for [GEMMs](./include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp), and [stream-K](./include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp).
- Extensions to testbeds and reference check code for unit tests and CUTLASS profiler.
-* Full support for Blackwell kernels in CUTLASS 3.x API:
+* Full support for Blackwell SM100 kernels in CUTLASS 3.x API:
- [Blackwell specific kernel layers](./include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp) that
- + Implement a new warp-specialization recipe tuned specifically for Blackwell.
+ + Implement a new warp-specialization recipe tuned specifically for Blackwell SM100 architecture.
+ Leverage all the new features such as CLC based tile scheduling, preferred cluster, and TMEM based double buffering of accumulators.
+ Support stream-K load balancing for all kernel types everywhere via composable scheduler support.
- Blackwell collective mainloops that target the TCGen05 MMA instructions (both SS and TS) for
@@ -73,7 +73,10 @@ For a background on Blackwell's new features, please consult the PTX documentati
* CUTLASS library and profiler integration for block scaled data types for kernel emission, profiling, and verification.
  - Support for preferred and fallback cluster shapes via profiler command line argument parsing to set dynamic cluster shapes.
  - Support for dynamic datatypes via profiler command line argument parsing to set the dynamic datatype in TCGen05 MMA instruction descriptors.
-* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell
+ - Support for mixed input GEMM kernels on Hopper in the profiler.
+* New CUTLASS profiler flag `use-cuda-graphs` to reduce overheads when benchmarking launch-bound kernels.
+* A new 3.x version of grouped GEMM in the CUTLASS library that generates kernels for Hopper and Blackwell. Grouped GEMM support is now enabled in the CUTLASS profiler (`./cutlass_profiler --operation=GroupedGemm --help` for details).
+* Set of examples that demonstrate the usage of the 3.x API for targeting Blackwell SM100 architecture:
- [Basic FP16 and FP8 GEMMs with minimal changes from Hopper examples](./examples/70_blackwell_gemm/), demonstrating ease of migration for off the shelf kernels using the 3.x collective builder API.
- GEMM with [opt-in collective builder schedules showcasing available recipes](./examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu) for Blackwell.
- Block scaled data type GEMMs targeting Blackwell's native block scaled Tensor Cores:
@@ -85,6 +88,10 @@ For a background on Blackwell's new features, please consult the PTX documentati
- Grouped GEMM for [vanilla FP8 data inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu) and [NVFP4 block scaled inputs](./examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu).
- Convolution kernels for [fprop](./examples/76_blackwell_conv/76_blackwell_conv_fprop.cu), [dgrad](./examples/76_blackwell_conv/76_blackwell_conv_dgrad.cu), and [wgrad](./examples/76_blackwell_conv/76_blackwell_conv_wgrad.cu).
  - [Fused multi-head attention fprop kernel](./examples/77_blackwell_fmha/77_blackwell_fmha.cu) supporting fp16/bf16/fp8 data types across head dims of 32, 64, and 128.
+ - A new BF16x9 GEMM [kernel](./examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu) that emulates FP32 GEMM (SGEMM) using BF16 operations.
+* Set of examples that demonstrate the usage of the 3.x API for targeting Hopper architecture:
+ - A set of new [Hopper grouped GEMM kernels](./examples/69_hopper_mixed_dtype_grouped_gemm/) that support mixed A and B datatypes.
+ - A new [Hopper FP8 GEMM with groupwise scaling](./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu).
* Documentation updates:
- [Quickstart - instantiating a Blackwell block-scaled GEMM](./media/docs/quickstart.md#instantiating-a-blackwell-gemm-kernel).
- Detailed [Blackwell block-scaled GEMM functionality documentation](./media/docs/blackwell_functionality.md)
diff --git a/customConfigs.cmake b/customConfigs.cmake
index c86e15be..e39212db 100644
--- a/customConfigs.cmake
+++ b/customConfigs.cmake
@@ -47,7 +47,7 @@ function(cutlass_generate_kernel_filter_and_testlists_files)
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR}
${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py
--generator-target=${__TEST_SET_NAME}
- --cuda-version=${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}
+ --cuda-version=${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR}
--architectures=${CUTLASS_NVCC_ARCHS}
--kernels=\*
--disable-cutlass-package-imports
diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
index 0c407d34..74aa3614 100644
--- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
+++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
@@ -45,7 +45,7 @@
CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can
improve performance.
Examples:
- $ ./examples/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling/64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
+ $ ./examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling \
--m=2816 --n=3072 --k=16384 \
--save_aux=false --save_amax=false \
--device_scale=false --raster=h --swizzle=2
@@ -119,22 +119,56 @@ using ElementBias = float;
using ElementAccumulator = float; // Element type for internal accumulation
using ElementBlockScale = float; // Element type for blockscaling during accumulation
using ElementCompute = float; // Element type for epilogue computation
-using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
-using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
-using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
-using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
-constexpr int ScaleMsPerTile = 2;
-constexpr int ScaleGranularityM = size<0>(TileShape{}) / ScaleMsPerTile;
+using TileShape_ = Shape<_128,_128,_128>; // Namespace-scope tile shape used by the config aliases below and by the non-templated verify()
-using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
-using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor
+// Given TileShape = Shape<_128,_128,_128>:
+// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor)
+// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling
+// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling
+// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling
+template <int ScaleGranularityM_, int ScaleGranularityN_>
+struct GroupScaleConfig {
+ using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
+ using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
+ using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
+ using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
-using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
-using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
+ static constexpr int ScaleGranularityM = ScaleGranularityM_;
+ static constexpr int ScaleGranularityN = ScaleGranularityN_;
+ static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
+ static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN;
+
+ static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
+ "FP8 scaling granularity must evenly divide tile shape along M.");
+ static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile,
+ "FP8 scaling granularity must evenly divide tile shape along N.");
+
+ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+ using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
+ using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
+ using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;
+};
-using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>;
+using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>;
+using GroupScale2D1DConfig = GroupScaleConfig<size<0>(TileShape_{}), 1>;
+using GroupScale2D2DConfig = GroupScaleConfig<size<0>(TileShape_{}), size<1>(TileShape_{})>;
+
+template <typename ScheduleConfig>
+struct GroupScaleGemm {
+ using ArchTag = typename ScheduleConfig::ArchTag;
+ using OperatorClass = typename ScheduleConfig::OperatorClass;
+ using TileShape = typename ScheduleConfig::TileShape;
+ using ClusterShape = typename ScheduleConfig::ClusterShape;
+ using KernelSchedule = typename ScheduleConfig::KernelSchedule;
+ using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
+ using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
+ using FusionOperation = typename ScheduleConfig::FusionOperation;
+
+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
TileShape, ClusterShape,
EpilogueTileType,
@@ -145,7 +179,7 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
FusionOperation
>::CollectiveOp;
-using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
+ using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
@@ -157,24 +191,38 @@ using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collectiv
KernelSchedule
>::CollectiveOp;
-using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
-    Shape<int,int,int,int>, // Indicates ProblemShape
- CollectiveMainloopWithBlockWiseScaling,
- CollectiveEpilogue
->;
+ using GemmKernelDefault = cutlass::gemm::kernel::GemmUniversal<
+    Shape<int,int,int,int>,
+ CollectiveMainloopWithGroupWiseScaling,
+ CollectiveEpilogue
+ >;
-using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+ using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
+    Shape<int,int,int,int>,
+ CollectiveMainloopWithGroupWiseScaling,
+ CollectiveEpilogue,
+ cutlass::gemm::StreamKScheduler
+ >;
+
+  using GemmDefault = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelDefault>;
+  using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;
+};
+
+using GroupScale1D1DGemm = GroupScaleGemm<GroupScale1D1DConfig>;
+using GroupScale1D2DGemm = GroupScaleGemm<GroupScale1D2DConfig>;
+using GroupScale2D1DGemm = GroupScaleGemm<GroupScale2D1DConfig>;
+using GroupScale2D2DGemm = GroupScaleGemm<GroupScale2D2DConfig>;
// Extract information from Gemm kernel.
-using EpilogueOutputOp = typename Gemm::EpilogueOutputOp;
+using EpilogueOutputOp = typename GroupScale1D1DGemm::GemmDefault::EpilogueOutputOp;
using ElementScalar = typename EpilogueOutputOp::ElementScalar;
using ElementAmax = typename EpilogueOutputOp::ElementAmax;
using ActivationFunctor = typename EpilogueOutputOp::ActivationFn;
-using StrideA = typename Gemm::GemmKernel::StrideA;
-using StrideB = typename Gemm::GemmKernel::StrideB;
-using StrideC = typename Gemm::GemmKernel::StrideC;
-using StrideD = typename Gemm::GemmKernel::StrideD;
+using StrideA = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideA;
+using StrideB = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideB;
+using StrideC = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideC;
+using StrideD = typename GroupScale1D1DGemm::GemmDefault::GemmKernel::StrideD;
using StrideAux = StrideD;
constexpr bool IsDFp8 =
@@ -185,9 +233,6 @@ constexpr bool IsAuxFp8 =
    cute::is_same_v<ElementAux, cutlass::float_e4m3_t> or
    cute::is_same_v<ElementAux, cutlass::float_e5m2_t>;
-static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile,
- "FP8 scaling granularity must evenly divide tile shape along M.");
-
 static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
"ElementAccumulator and ElementBlockScale should be same datatype");
@@ -347,13 +392,18 @@ struct Result
}
/// Initialize operands to be used in the GEMM and reference GEMM
+template <typename GroupScaleConfig>
void initialize(const Options &options) {
+ using TileShape = typename GroupScaleConfig::TileShape;
+ const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
+ const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+
// Find Group Scaling tensor shapes based on `ScaleGranularityM`, problem shape, and TileShape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access.
- auto blockscale_n = cute::get<1>(blockscale_shape);
+  auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of B to prevent illegal memory access.
auto blockscale_k = cute::get<2>(blockscale_shape);
stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l));
@@ -362,18 +412,16 @@ void initialize(const Options &options) {
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l));
stride_aux = stride_D;
-
-
auto a_coord = cutlass::make_Coord(options.m * options.l, options.k);
auto c_coord = cutlass::make_Coord(options.m * options.l, options.n);
auto b_coord = cutlass::make_Coord(options.k, options.n * options.l);
auto groupscale_a_coord = cutlass::make_Coord(groupscale_m * options.l, blockscale_k);
- auto blockscale_b_coord = cutlass::make_Coord(blockscale_k, blockscale_n * options.l);
+ auto groupscale_b_coord = cutlass::make_Coord(groupscale_n * options.l, blockscale_k);
tensor_A.resize(a_coord);
- blockscale_tensor_A.resize(groupscale_a_coord);
tensor_B.resize(b_coord);
- blockscale_tensor_B.resize(blockscale_b_coord);
+ blockscale_tensor_A.resize(groupscale_a_coord);
+ blockscale_tensor_B.resize(groupscale_b_coord);
tensor_C.resize(c_coord);
tensor_D.resize(c_coord);
tensor_ref_D.resize(c_coord);
@@ -393,7 +441,7 @@ void initialize(const Options &options) {
#if 0 // Dump blockscaled tensors
std::cout << "blockscale_tensor_A: " << groupscale_a_coord << std::endl;
std::cout << blockscale_tensor_A.host_view() << "\n";
- std::cout << "blockscale_tensor_B: " << blockscale_b_coord << std::endl;
+ std::cout << "blockscale_tensor_B: " << groupscale_b_coord << std::endl;
std::cout << blockscale_tensor_B.host_view() << "\n";
#endif
@@ -441,21 +489,26 @@ void initialize(const Options &options) {
if (IsDFp8 && options.save_amax) {
abs_max_D.resize(cutlass::make_Coord(1));
+ initialize_tensor(abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_D.sync_device();
reference_abs_max_D.resize(cutlass::make_Coord(1));
+ initialize_tensor(reference_abs_max_D.host_view(), cutlass::Distribution::AllZeros, 0);
}
if (IsAuxFp8 && options.save_aux && options.save_amax) {
abs_max_aux.resize(cutlass::make_Coord(1));
+ initialize_tensor(abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
abs_max_aux.sync_device();
reference_abs_max_aux.resize(cutlass::make_Coord(1));
+ initialize_tensor(reference_abs_max_aux.host_view(), cutlass::Distribution::AllZeros, 0);
}
}
/// Populates a Gemm::Arguments structure from the given commandline options
-typename Gemm::Arguments args_from_options(const Options &options)
+template <typename GemmArguments>
+GemmArguments args_from_options(const Options &options)
{
- typename Gemm::Arguments arguments{
+ GemmArguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{options.m, options.n, options.k, options.l},
{tensor_A.device_data(),
@@ -513,14 +566,15 @@ typename Gemm::Arguments args_from_options(const Options &op
return arguments;
}
-bool verify(const Options &options) {
+/// verify() is kept non-templated; the per-tile scale counts are passed in as runtime arguments instead.
+bool verify(const Options &options, const int ScaleMsPerTile, const int ScaleNsPerTile) {
//
// Compute reference output
//
  // Group scaling tensor shapes based on `ScaleGranularityM`, the CTA block (TileShape), and the GEMM problem shape
auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k);
- auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape{})));
+ auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{})));
auto blockscale_m = cute::get<0>(blockscale_shape);
auto blockscale_n = cute::get<1>(blockscale_shape);
auto blockscale_k = cute::get<2>(blockscale_shape);
@@ -565,8 +619,8 @@ bool verify(const Options &options) {
);
auto blockscale_B = cute::make_tensor(blockscale_tensor_B.host_data(),
cute::make_layout(
- cute::make_shape(blockscale_n, blockscale_k, options.l),
- cute::make_stride(blockscale_k, 1, blockscale_n * blockscale_k)
+ cute::make_shape(blockscale_n, ScaleNsPerTile, blockscale_k, options.l),
+ cute::make_stride(blockscale_k * ScaleNsPerTile, 1, ScaleNsPerTile, blockscale_n * blockscale_k * ScaleNsPerTile)
)
);
@@ -575,7 +629,7 @@ bool verify(const Options &options) {
  cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B), decltype(blockscale_A), decltype(blockscale_B),
-      TileShape> mainloop_params{
+      TileShape_> mainloop_params{
A, B, // Operand Tensors
blockscale_A, blockscale_B // Groupwise scaling Tensors
};
@@ -641,16 +695,22 @@ bool verify(const Options &options) {
}
/// Execute a given example GEMM computation
-template <typename Gemm>
+template <typename GroupScaleConfig, typename Gemm>
int run(Options &options)
{
- initialize(options);
+ using TileShape = typename GroupScaleConfig::TileShape;
+ const int ScaleGranularityM = GroupScaleConfig::ScaleGranularityM;
+ const int ScaleGranularityN = GroupScaleConfig::ScaleGranularityN;
+ const int ScaleMsPerTile = GroupScaleConfig::ScaleMsPerTile;
+ const int ScaleNsPerTile = GroupScaleConfig::ScaleNsPerTile;
+
+  initialize<GroupScaleConfig>(options);
// Instantiate CUTLASS kernel depending on templates
Gemm gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm
- auto arguments = args_from_options(options);
+  auto arguments = args_from_options<typename Gemm::Arguments>(options);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = Gemm::get_workspace_size(arguments);
@@ -669,7 +729,7 @@ int run(Options &options)
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
- result.passed = verify(options);
+ result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
@@ -683,6 +743,7 @@ int run(Options &options)
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
+ CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
@@ -702,9 +763,13 @@ int run(Options &options)
}
std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+ std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl;
+ std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl;
+ std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl;
std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl;
std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
std::cout << " GFLOPS: " << result.gflops << std::endl;
+ fflush(stdout);
}
return 0;
@@ -753,7 +818,27 @@ int main(int argc, char const **args) {
//
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
-  run<Gemm>(options);
+ std::cout << "Basic split-K GEMM kernel" << std::endl;
+  run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmDefault>(options);
+  std::cout << std::endl;
+  run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmDefault>(options);
+  std::cout << std::endl;
+  run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmDefault>(options);
+  std::cout << std::endl;
+  run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmDefault>(options);
+ std::cout << std::endl;
+
+ std::cout << std::endl;
+
+ std::cout << "StreamK GEMM kernel" << std::endl;
+  run<GroupScale1D1DConfig, GroupScale1D1DGemm::GemmStreamK>(options);
+  std::cout << std::endl;
+  run<GroupScale1D2DConfig, GroupScale1D2DGemm::GemmStreamK>(options);
+  std::cout << std::endl;
+  run<GroupScale2D1DConfig, GroupScale2D1DGemm::GemmStreamK>(options);
+  std::cout << std::endl;
+  run<GroupScale2D2DConfig, GroupScale2D2DGemm::GemmStreamK>(options);
+ std::cout << std::endl;
#endif
return 0;
diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
index cb3ff022..e9809f6b 100644
--- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
+++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
@@ -220,10 +220,12 @@ void gett_mainloop(
int64_t block_m = m / kBlockM;
int64_t block_n = n / kBlockN;
cute::Tensor blockscale_A = mainloop_params.ScaleA(block_m, _, _, l);
- cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, l);
+ cute::Tensor blockscale_B = mainloop_params.ScaleB(block_n, _, _, l);
const int ScaleGranularityM = cute::size<0>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleA.shape());
- assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
+ const int ScaleGranularityN = cute::size<1>(typename MainloopParams::TileShape{}) / cute::size<1>(mainloop_params.ScaleB.shape());
+ assert(cute::size<0>(typename MainloopParams::TileShape{}) == ScaleGranularityM * cute::size<1>(mainloop_params.ScaleA.shape()));
+ assert(cute::size<1>(typename MainloopParams::TileShape{}) == ScaleGranularityN * cute::size<1>(mainloop_params.ScaleB.shape()));
// Compute on this k-block
for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) {
@@ -231,7 +233,7 @@ void gett_mainloop(
// Load Blockwise scaling factor from blockscale Tensors for B
int64_t block_k = k / kBlockK;
cute::Tensor scale_a = blockscale_A(_, block_k);
- ElementBlockScaleB scale_b = blockscale_B[block_k];
+ cute::Tensor scale_b = blockscale_B(_, block_k);
// Load A
ElementAccumulator a_frag[kBlockM];
@@ -268,8 +270,10 @@ void gett_mainloop(
// (c) Update permanent (accu)
if ((k+1) % kBlockK == 0) {
for (int m_b = 0; m_b < kBlockM; ++m_b) {
+ auto scale_a_m_b = scale_a[m_b / ScaleGranularityM];
for (int n_b = 0; n_b < kBlockN; ++n_b) {
- ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a[m_b / ScaleGranularityM] * scale_b;
+ auto scale_b_n_b = scale_b[n_b / ScaleGranularityN];
+ ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b;
acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b];
acc_temp[m_b][n_b] = ElementAccumulator(0);
}
diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu
index b22d8305..c1978c32 100644
--- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu
+++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm.cu
@@ -32,7 +32,19 @@
/*! \file
\brief
- NOTE: Write docu
+    Mixed-input Grouped GEMM example using CUTLASS 3.x APIs for the NVIDIA Hopper architecture.
+ See 55_hopper_int4_bf16_gemm.cu for more details about W4A16 GEMMs with layout shuffling.
+
+ Limitations:
+      1) Only row-wise scaling is supported. Zero-points and block-wise scaling are currently not supported.
+
+ To run this example:
+
+ $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_bf16_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
+
+    The above command sizes all 10 groups at the given m, n, and k values.
+    Skipping any of the problem dimensions randomizes it across the different groups.
+    The same applies to the alpha and beta values, which are randomized across the different groups.
*/
 #include <iostream>
diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu
index cc0494ec..07ff66b3 100644
--- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu
+++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm.cu
@@ -32,7 +32,19 @@
/*! \file
\brief
- NOTE: Write docu
+    Mixed-input Grouped GEMM example using CUTLASS 3.x APIs for the NVIDIA Hopper architecture.
+ See 55_hopper_int4_fp8_gemm.cu for more details about W4A8 GEMMs with lookup table.
+
+ Limitations:
+      1) Only row-wise scaling is supported. Zero-points and block-wise scaling are currently not supported.
+
+ To run this example:
+
+ $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_int4_fp8_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
+
+    The above command sizes all 10 groups at the given m, n, and k values.
+    Skipping any of the problem dimensions randomizes it across the different groups.
+    The same applies to the alpha and beta values, which are randomized across the different groups.
*/
 #include <iostream>
diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu
index 883d8cbf..ffeb233e 100644
--- a/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu
+++ b/examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm.cu
@@ -31,7 +31,19 @@
/*! \file
\brief
- NOTE: Write docu
+    Mixed-input Grouped GEMM example using CUTLASS 3.x APIs for the NVIDIA Hopper architecture.
+ See 55_hopper_mixed_dtype_gemm.cu for more details about Mixed-input GEMMs.
+
+ Limitations:
+      1) Only row-wise scaling is supported. Zero-points and block-wise scaling are currently not supported.
+
+ To run this example:
+
+ $ ./examples/69_hopper_mixed_dtype_grouped_gemm/69_hopper_mixed_dtype_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
+
+    The above command sizes all 10 groups at the given m, n, and k values.
+    Skipping any of the problem dimensions randomizes it across the different groups.
+    The same applies to the alpha and beta values, which are randomized across the different groups.
*/
 #include <iostream>
diff --git a/examples/69_hopper_mixed_dtype_grouped_gemm/README.md b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md
new file mode 100644
index 00000000..272d36e5
--- /dev/null
+++ b/examples/69_hopper_mixed_dtype_grouped_gemm/README.md
@@ -0,0 +1,14 @@
+This example extends Example 55 to support Grouped GEMMs in CUTLASS.
+
+## High level overview
+
+This example shows how to perform Grouped GEMMs on Hopper when A and B have different types. In a Grouped GEMM, multiple GEMMs with potentially different problem shapes can be executed in a batch. The interface is similar to the standard mixed-input GEMM presented in Example 55, with a few noteworthy differences (see the sketch after this list):
+- inside the collective builder, replace the layout types with layout pointer types.
+- in the arguments, pass the group count, the array of problem sizes, and the arrays of strides for matrices A and B.
+- if scales and zero-points are included, also pass the arrays of their strides in the arguments.
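+
+A minimal sketch of the builder-level difference, assuming element/layout/tile aliases in the style of Example 55 (the identifiers below are illustrative, not the exact ones used in this example):
+
+```c++
+// Grouped mixed-input mainloop: the quantized operand is paired with its scale
+// type, and the builder receives pointer-to-layout types (one layout per group)
+// instead of a single layout type.
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    cute::tuple<ElementA, ElementScale>, LayoutA *, AlignmentA,  // note LayoutA * rather than LayoutA
+    ElementB, LayoutB *, AlignmentB,                             // note LayoutB * rather than LayoutB
+    ElementAccumulator,
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAuto,
+    cutlass::gemm::collective::KernelScheduleAuto
+  >::CollectiveOp;
+```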
+
+Note that in Example 55, the argument `--g` is used to determine the block scale size. It is important not to confuse this with the `--groups` argument in this example, which specifies the number of GEMMs.
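+
+For instance (the flag values below are illustrative):
+
+```
+# Example 55: a single mixed-dtype GEMM whose scale blocks span g=128 elements
+./55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --g=128
+# This example: --groups sets the number of GEMMs in the group, not a scale size
+./69_hopper_mixed_dtype_grouped_gemm --m=2048 --n=2048 --k=2048 --mode=1 --groups=10
+```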
+
+## Upcoming features
+
+Currently, the Mixed-input Grouped GEMM only supports row-wise scaling. Please contact us if zero-points or block-wise scaling are needed.
diff --git a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
index 39123cac..3cee6caf 100644
--- a/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
+++ b/examples/70_blackwell_gemm/70_blackwell_fp16_gemm.cu
@@ -115,15 +115,11 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
using MmaTileShape_MNK = Shape<_256,_128,_64>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_2,_1>;
-// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
-using AtomThrShape_MNK = Shape<_2, _1, _1>;
-// Shape of the tile computed by each SM
-using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape_MNK,
+ MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
diff --git a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
index 0b1758b9..69a36310 100644
--- a/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
+++ b/examples/70_blackwell_gemm/70_blackwell_fp8_gemm.cu
@@ -131,17 +131,13 @@ using ElementAmax = float;
using MmaTileShape_MNK = Shape<_256,_128,_64>;
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = Shape<_2,_2,_1>;
-// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
-using AtomThrShape_MNK = Shape<_2, _1, _1>;
-// Shape of the tile computed by each SM
-using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
using FusionOp = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutC, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
- PerSmTileShape_MNK, ClusterShape_MNK,
+ MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
diff --git a/examples/70_blackwell_gemm/CMakeLists.txt b/examples/70_blackwell_gemm/CMakeLists.txt
index d88a8c56..cb401e3a 100644
--- a/examples/70_blackwell_gemm/CMakeLists.txt
+++ b/examples/70_blackwell_gemm/CMakeLists.txt
@@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
70_blackwell_fp16_gemm
70_blackwell_fp16_gemm.cu
diff --git a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu
index 6712d7a9..427af254 100644
--- a/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu
+++ b/examples/71_blackwell_gemm_with_collective_builder/71_blackwell_gemm_with_collective_builder.cu
@@ -184,12 +184,8 @@ struct ExampleRunner {
     std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100> ||
// Auto schedule will try to select 2sm cluster MMA based on cluster M
     std::is_same_v<MainloopScheduleType, cutlass::gemm::collective::KernelScheduleAuto> && size<0>(ClusterShapeMNK{}) % 2 == 0;
- // The MNK layout of CTAs within a cluster MMA
-  using AtomThrMNK = std::conditional_t<Use2SmMma, Shape<_2,_1,_1>, Shape<_1,_1,_1>>;
// The MMA tile used by the mainloop collective. Blackwell 1sm MMA supports up to MMA tile M = 128, 2sm MMA supports up to MMA tile M = 256
   using MmaTileMNK = std::conditional_t<Use2SmMma, Shape<_256,_128,_64>, Shape<_128,_128,_64>>;
- // The Output tile used by the epilogue collective
- using OutputTileMNK = decltype(shape_div(MmaTileMNK{}, AtomThrMNK{}));
// 16B alignment lets us use TMA
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
@@ -220,7 +216,7 @@ struct ExampleRunner {
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
- OutputTileMNK, ClusterShapeMNK,
+ MmaTileMNK, ClusterShapeMNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementCompute,
ElementC, LayoutC, AlignmentC,
@@ -503,20 +499,20 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
print_result("KernelScheduleAuto mainloop schedule with EpilogueScheduleAuto epilogue schedule and 3 mainloop stages", passed);
// 1SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
-  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_2;
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::NoSmemWarpSpecialized1Sm> runner_2;
passed = runner_2.run(options, hw_info);
- print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
+ print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue schedule", passed);
// 1SM cluster MMA mainloop schedules can also be used with 1SM TMA epilogue schedules
// 1SM cluster MMA mainloop schedules will not work with 2SM TMA epilogue schedules
   ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmSm100, cutlass::epilogue::TmaWarpSpecialized1Sm> runner_3;
passed = runner_3.run(options, hw_info);
- print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
+ print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with TmaWarpSpecialized1Sm epilogue schedule", passed);
// 2SM cluster MMA mainloop schedules can be used with direct store ("no-smem") epilogue schedules
-  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::NoSmemWarpSpecialized> runner_4;
+  ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::epilogue::NoSmemWarpSpecialized2Sm> runner_4;
passed = runner_4.run(options, hw_info);
- print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue schedule", passed);
+ print_result("KernelTmaWarpSpecialized2SmSm100 mainloop schedule with NoSmemWarpSpecialized2Sm epilogue schedule", passed);
// 2SM cluster MMA mainloop schedules can also be used with 2SM TMA epilogue schedules
   // 2SM cluster MMA mainloop schedules will not work with 1SM TMA epilogue schedules
@@ -556,11 +552,11 @@ if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MIN
// Blackwell direct store epilogue schedule supports custom EVTs and named fusion operations as well (not supported for pre-Blackwell kernels)
ExampleRunner<
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100,
- cutlass::epilogue::NoSmemWarpSpecialized,
+ cutlass::epilogue::NoSmemWarpSpecialized1Sm,
cutlass::gemm::collective::StageCountAuto,
UseCustomEVT> runner_9;
passed = runner_9.run(options, hw_info);
- print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized epilogue and custom EVT", passed);
+ print_result("KernelTmaWarpSpecialized1SmSm100 mainloop schedule with NoSmemWarpSpecialized1Sm epilogue and custom EVT", passed);
#endif
diff --git a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt
index 5bac6494..a326f461 100644
--- a/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt
+++ b/examples/71_blackwell_gemm_with_collective_builder/CMakeLists.txt
@@ -27,7 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Both filenames are shorter to avoid MAX_PATH issues on Windows.
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
71_blackwell_gemm_with_collective_builder
71_blackwell_gemm_with_collective_builder.cu
diff --git a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu
index ec597966..f7e12fbf 100644
--- a/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu
+++ b/examples/72_blackwell_narrow_precision_gemm/72a_blackwell_nvfp4_bf16_gemm.cu
@@ -117,11 +117,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
-using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape,
+ MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
diff --git a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu
index cefa3e92..2719cab9 100644
--- a/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu
+++ b/examples/72_blackwell_narrow_precision_gemm/72b_blackwell_nvfp4_nvfp4_gemm.cu
@@ -121,7 +121,6 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_128,_128,_256>; // MMA's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
-using PerSmTileShape_MNK = Shape<_128,_128,_256>; // Threadblock-level tile size
constexpr int InputSFVectorSize = 16;
constexpr int OutputSFVectorSize = InputSFVectorSize;
@@ -137,7 +136,7 @@ using FusionOperation = cutlass::epilogue::fusion::LinCombBlockScaleFactor<
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape,
+ MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
diff --git a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu
index b73f2c94..2784d050 100644
--- a/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu
+++ b/examples/72_blackwell_narrow_precision_gemm/72c_blackwell_mixed_mxfp8_bf16_gemm.cu
@@ -118,11 +118,10 @@ using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // O
// Kernel Perf config
using MmaTileShape = Shape<_256,_256,_256>; // MMA's tile size
using ClusterShape = Shape<_4,_4,_1>; // Shape of the threadblocks in a cluster
-using PerSmTileShape_MNK = Shape<_128,_256,_256>; // Threadblock-level tile size
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape,
+ MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
diff --git a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
index fa80c184..eaeb6600 100644
--- a/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
+++ b/examples/72_blackwell_narrow_precision_gemm/CMakeLists.txt
@@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
72a_blackwell_nvfp4_bf16_gemm
72a_blackwell_nvfp4_bf16_gemm.cu
diff --git a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
index 0d0f7757..a4a18324 100644
--- a/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
+++ b/examples/73_blackwell_gemm_preferred_cluster/CMakeLists.txt
@@ -28,7 +28,7 @@
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
73_blackwell_gemm_preferred_cluster
blackwell_gemm_preferred_cluster.cu
diff --git a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu
index fb62e844..19c4efd1 100644
--- a/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu
+++ b/examples/73_blackwell_gemm_preferred_cluster/blackwell_gemm_preferred_cluster.cu
@@ -129,27 +129,22 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
using MmaTileShape_MNK = Shape<_256,_128,_64>;
-// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
-using AtomThrShape_MNK = Shape<_2, _1, _1>;
-// Shape of the tile computed by each SM
-using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
 // Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
 using ClusterShape_MNK = Shape<int,int,_1>;
// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
-// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1>
-using KernelSchedule = cute::conditional_t<size(AtomThrShape_MNK{}) == 2, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::gemm::collective::KernelScheduleAuto>;
+// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape_MNK,
+ MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
- cutlass::epilogue::collective::EpilogueScheduleAuto
+ EpilogueSchedule
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
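
With a runtime cluster shape the concrete cluster dimensions arrive only at launch time, so the 2SM schedules are pinned explicitly rather than derived from an AtomThrShape_MNK. A minimal sketch of supplying the dynamic cluster at runtime, assuming the hw_info field names used by the CUTLASS 3.x examples:

    typename Gemm::Arguments arguments = /* problem shape, pointers, epilogue params... */;
    // Preferred and fallback cluster shapes are plain dim3 values chosen at runtime.
    // M must be even here because the schedule above commits to tcgen05 cta_group = 2.
    arguments.hw_info.cluster_shape          = dim3(4, 2, 1);
    arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1);
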
diff --git a/examples/74_blackwell_gemm_streamk/CMakeLists.txt b/examples/74_blackwell_gemm_streamk/CMakeLists.txt
index 618561f5..5a378241 100644
--- a/examples/74_blackwell_gemm_streamk/CMakeLists.txt
+++ b/examples/74_blackwell_gemm_streamk/CMakeLists.txt
@@ -29,7 +29,7 @@
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
74_blackwell_gemm_streamk
blackwell_gemm_streamk.cu
diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
index bb99fa4a..8f6def99 100644
--- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
+++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu
@@ -133,22 +133,17 @@ using OperatorClass = cutlass::arch::OpClassTensorOp; // O
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape % 2 == 0
using MmaTileShape_MNK = Shape<_256,_128,_64>;
-// Shape of the threadblocks participating in a tcgen05 MMA. <1, 1, 1> for cta_group = 1, <2, 1, 1> for cta_group = 2
-using AtomThrShape_MNK = Shape<_2, _1, _1>;
-// Shape of the tile computed by each SM
-using PerSmTileShape_MNK = decltype(shape_div(MmaTileShape_MNK{}, AtomThrShape_MNK{}));
// Shape of the cluster set to <int,int,_1> to indicate dynamic cluster shape
using ClusterShape_MNK = Shape<int,int,_1>;
-// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
+// When dynamic cluster is used, KernelScheduleAuto always selects mainloop dispatch policy that
// lowers to tcgen05 MMA cta_group = 1 as we don't know if the dynamic cluster M dimension will be a multiple of 2
-// To use KernelScheduleAuto, users need to set AtomThrShape_MNK to Shape<1, 1, 1>
-using KernelSchedule = cute::conditional_t<cute::size<0>(AtomThrShape_MNK{}) == 2, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::gemm::collective::KernelScheduleAuto>;
+// To use tcgen05 MMA cta_group = 2, users must explicitly use 2sm builder schedules
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100;
+using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- PerSmTileShape_MNK, ClusterShape_MNK,
+ MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
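
As in the previous example, the collectives are built against the MMA tile shape; what makes this example stream-K is the kernel-level tile scheduler, which is selected independently of the collectives. A sketch of that composition, assuming the standard 3.x kernel assembly:

    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        Shape<int,int,int,int>,               // problem shape (M, N, K, L)
        CollectiveMainloop,
        CollectiveEpilogue,
        cutlass::gemm::StreamKScheduler>;     // stream-K work decomposition
    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
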
diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
index 520d8cee..1d8db6e2 100644
--- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
+++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm.cu
@@ -121,32 +121,25 @@ using StageCountType = cutlass::gemm::collective::StageCountAuto; // S
// Runtime Cluster Shape
using ClusterShape = Shape<int,int,_1>;
-// For Static Cluster Shape:
-// using ClusterShape = Shape<_2,_1,_1>; // for example
-// using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config
-// using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder
-// using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder
// Different configs for 1SM and 2SM MMA kernel
struct MMA1SMConfig {
using MmaTileShape = Shape<_128,_256,Int<128 / sizeof(ElementA)>>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
- using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{}));
};
struct MMA2SMConfig {
using MmaTileShape = Shape<_256,_256,Int<128 / sizeof(ElementA)>>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
- using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{}));
};
template <class ScheduleConfig>
struct GivenGemmSchedule {
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- typename ScheduleConfig::OutputTileShape, ClusterShape,
+ typename ScheduleConfig::MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
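
GivenGemmSchedule lets the 1SM and 2SM grouped-GEMM variants share one builder body, differing only in the tile shape and schedule tags carried by the config structs. A sketch of the selection step, with ::Gemm standing in for the kernel alias the wrapper exposes (compare the `using Gemm = Gemm1SM;` context in the block-scaled hunk below):

    using Gemm1SM = GivenGemmSchedule<MMA1SMConfig>::Gemm;  // one CTA per MMA
    using Gemm2SM = GivenGemmSchedule<MMA2SMConfig>::Gemm;  // two CTAs cooperate per MMA
    using Gemm    = Gemm1SM;
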
diff --git a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu
index fa65e508..ee697135 100644
--- a/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu
+++ b/examples/75_blackwell_grouped_gemm/75_blackwell_grouped_gemm_block_scaled.cu
@@ -143,31 +143,23 @@ using StageCountType = cutlass::gemm::collective::StageCountAuto; // S
// Runtime Cluster Shape
using ClusterShape = Shape<int,int,_1>;
-/* // For Static Cluster Shape:
-use ClusterShape = Shape<_2,_1,_1> for example
-using AtomThrShape = decltype(shape_div(ClusterShape{}, Shape<_2,_1,_1>{})); // for 2SM config
-using OutputTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // for epilogue builder
-using MmaTileShape = decltype(shape_div(ClusterTileShape{}, AtomThrShape{})); // for mainloop builder
-*/
// Different configs for 1SM and 2SM MMA kernel
struct MMA1SMConfig {
using MmaTileShape = Shape<_128,_256,_256>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
- using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_1,_1,_1>{}));
};
struct MMA2SMConfig {
using MmaTileShape = Shape<_256,_256,_256>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
- using OutputTileShape = decltype(shape_div(MmaTileShape{}, Shape<_2,_1,_1>{}));
};
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
- typename MMA1SMConfig::OutputTileShape, ClusterShape,
+ typename MMA1SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
@@ -195,7 +187,7 @@ using Gemm = Gemm1SM;
using CollectiveEpilogue2SM = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, EpilogueOperatorClass,
- typename MMA2SMConfig::OutputTileShape, ClusterShape,
+ typename MMA2SMConfig::MmaTileShape, ClusterShape,
Shape<_128,_64>,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC *, AlignmentC,
diff --git a/examples/75_blackwell_grouped_gemm/CMakeLists.txt b/examples/75_blackwell_grouped_gemm/CMakeLists.txt
index 2da2d4c4..0ce48662 100644
--- a/examples/75_blackwell_grouped_gemm/CMakeLists.txt
+++ b/examples/75_blackwell_grouped_gemm/CMakeLists.txt
@@ -49,7 +49,7 @@ set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=50 --iterations=0)
set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes
set(TEST_RANDOM_PERF_LARGE_GROUP --groups=50 --iterations=10) # Random problem sizes
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
75_blackwell_grouped_gemm
75_blackwell_grouped_gemm.cu
diff --git a/examples/76_blackwell_conv/CMakeLists.txt b/examples/76_blackwell_conv/CMakeLists.txt
index 8d31d743..e4042aa6 100644
--- a/examples/76_blackwell_conv/CMakeLists.txt
+++ b/examples/76_blackwell_conv/CMakeLists.txt
@@ -28,7 +28,7 @@
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
76_blackwell_conv_fprop
76_blackwell_conv_fprop.cu
diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt
index c840d8ba..90b47387 100644
--- a/examples/77_blackwell_fmha/CMakeLists.txt
+++ b/examples/77_blackwell_fmha/CMakeLists.txt
@@ -49,7 +49,7 @@ set(TEST_GEN_REMAP --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --remap)
set(TEST_GEN_CACHEONLY --b=2 --h=4 --h_k=2 --k=512 --d=128 --verify --cache-only)
if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")))
- if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+ if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
77_blackwell_fmha_fp8
77_blackwell_fmha.cu
diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu
index 48e1da6c..f50e85b4 100644
--- a/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu
+++ b/examples/78_blackwell_emulated_bf16x9_gemm/78_blackwell_emulated_bf16x9_gemm.cu
@@ -106,25 +106,23 @@ using ArchTag = cutlass::arch::Sm100; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
// Kernel Perf config
-using ClusterTileShape = Shape<_256,_128,_16>; // Cluster-level tile shape
-using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
-using CtaTileShape = decltype(shape_div(ClusterTileShape{}, ClusterShape{})); // Threadblock-level tile shape
-using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape
+using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster
+using MmaTileShape = Shape<_256,_128,_16>; // Mma instruction shape
// Build the epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
- CtaTileShape, ClusterShape,
+ MmaTileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutC, AlignmentC,
ElementC, LayoutC, AlignmentC,
- cutlass::epilogue::NoSmemWarpSpecialized
+ cutlass::epilogue::NoSmemWarpSpecialized2Sm
>::CollectiveOp;
// Build the mainloop
// Note: Emulated BF16x9 kernels need to manually specify a mainloop schedule and cannot use KernelScheduleAuto
-using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedFastFP32SmemSm100;
+using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmFastFP32SmemSm100;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
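
For reference, the arithmetic this example emulates: each FP32 operand is split into three BF16 terms covering successive mantissa chunks, so one FP32 product becomes 3 x 3 = 9 BF16 products accumulated in FP32, hence "BF16x9". A standalone scalar sketch of the idea (not the kernel's actual tensor-core code path):

    #include <cuda_bf16.h>

    // Split an FP32 value into three BF16 terms whose sum approximates it.
    __device__ void split_bf16x3(float x, __nv_bfloat16& hi, __nv_bfloat16& mid, __nv_bfloat16& lo) {
      hi  = __float2bfloat16(x);
      float r = x - __bfloat162float(hi);
      mid = __float2bfloat16(r);
      lo  = __float2bfloat16(r - __bfloat162float(mid));
    }

    // Emulate one FP32 multiply with 9 BF16 products accumulated in FP32.
    __device__ float mul_bf16x9(float a, float b) {
      __nv_bfloat16 a3[3], b3[3];
      split_bf16x3(a, a3[0], a3[1], a3[2]);
      split_bf16x3(b, b3[0], b3[1], b3[2]);
      float acc = 0.f;
      for (int i = 2; i >= 0; --i)      // smallest products first limits rounding error
        for (int j = 2; j >= 0; --j)
          acc += __bfloat162float(a3[i]) * __bfloat162float(b3[j]);
      return acc;
    }
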
diff --git a/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt
index 1b36a4fd..6fcbd062 100644
--- a/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt
+++ b/examples/78_blackwell_emulated_bf16x9_gemm/CMakeLists.txt
@@ -28,7 +28,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-if(NOT CUTLASS_NVCC_ARCHS STREQUAL "100")
+if (CUTLASS_NVCC_ARCHS MATCHES 100a)
cutlass_example_add_executable(
78_blackwell_emulated_bf16x9_gemm
78_blackwell_emulated_bf16x9_gemm.cu
diff --git a/examples/README.md b/examples/README.md
index ec39bf22..68bf7077 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -254,7 +254,7 @@
Blackwell SM100 GEMM example demonstrating compatible mainloop+epilogue builder schedules and epilogue visitor tree (EVT) construction
-* [72a_blackwell_narrow_precision_gemm](72a_blackwell_narrow_precision_gemm)
+* [72_blackwell_narrow_precision_gemm](72_blackwell_narrow_precision_gemm/)
Block-scaled dense GEMM example targeting the NVIDIA Blackwell SM100 Tensor Core MMA using CUTLASS 3.x APIs.
@@ -278,6 +278,10 @@
Blackwell SM100 FMHA kernel
+* [78_blackwell_emulated_bf16x9_gemm](78_blackwell_emulated_bf16x9_gemm)
+
+ Blackwell SM100 FastFP32 (using BF16 to emulate SGEMM) kernel
+
# CuTe - Programming Examples
Examples that do not rely on CUTLASS and directly showcase the features of CuTe are located in [cutlass/examples/cute](./cute/).
diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp
index e1950b92..b97fc4c8 100644
--- a/include/cute/arch/config.hpp
+++ b/include/cute/arch/config.hpp
@@ -86,5 +86,3 @@
#define CUTE_ARCH_FLOAT2_MATH_ENABLED
#endif
-
-
diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp
index a157008c..f5f50647 100644
--- a/include/cute/arch/copy_sm90_desc.hpp
+++ b/include/cute/arch/copy_sm90_desc.hpp
@@ -208,6 +208,7 @@ to_CUtensorMapDataType() {
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
+ if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else
if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else
diff --git a/include/cute/arch/mma_sm100_umma.hpp b/include/cute/arch/mma_sm100_umma.hpp
index d954544f..1f74223b 100644
--- a/include/cute/arch/mma_sm100_umma.hpp
+++ b/include/cute/arch/mma_sm100_umma.hpp
@@ -956,7 +956,7 @@ template
struct SM100_MMA_MXF4_SS
{
- static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster OMMA.");
+ static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA.");
static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256.");
static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32.");
diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp
index cd344fd5..6a767ae3 100644
--- a/include/cute/atom/copy_traits_sm100.hpp
+++ b/include/cute/atom/copy_traits_sm100.hpp
@@ -45,7 +45,6 @@
namespace cute
{
-
template <>
struct Copy_Traits
{
diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp
index ea3eaf72..a431fc4a 100644
--- a/include/cute/container/array.hpp
+++ b/include/cute/container/array.hpp
@@ -372,7 +372,7 @@ void swap(array& a, array& b)
/// @return A cute::array of the elements of @c t in reverse order.
template <class T, size_t N>
CUTE_HOST_DEVICE constexpr
-cute::array reverse(cute::array const& t)
+cute::array reverse(cute::array const& t)
{
if constexpr (N == 0u) {
return t;
@@ -441,17 +441,6 @@ struct tuple_element>
using type = T;
};
-template <class T, size_t N>
-struct tuple_size<cute::array<T,N> const>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, N>
-{};
-
-template <size_t I, class T, size_t N>
-struct tuple_element<I, cute::array<T,N> const>
-{
- using type = T;
-};
-
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -477,16 +466,5 @@ struct tuple_element>
using type = T;
};
-template <class T, size_t N>
-struct tuple_size<cute::array<T,N> const>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, N>
-{};
-
-template <size_t I, class T, size_t N>
-struct tuple_element<I, cute::array<T,N> const>
-{
- using type = T;
-};
-
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
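
The const-qualified specializations removed above were redundant: the standard library already provides blanket partial specializations tuple_size<const T> and tuple_element<I, const T> that forward to the non-const forms, so a tuple-like type only needs the non-const specializations. A small sketch of that forwarding with a hypothetical type:

    #include <cstddef>
    #include <tuple>
    #include <type_traits>

    struct Pair2 { int a; float b; };   // hypothetical tuple-like type

    template <> struct std::tuple_size<Pair2>
        : std::integral_constant<std::size_t, 2> {};
    template <> struct std::tuple_element<0, Pair2> { using type = int; };
    template <> struct std::tuple_element<1, Pair2> { using type = float; };

    // No const specializations required; std forwards them automatically:
    static_assert(std::tuple_size<const Pair2>::value == 2);
    static_assert(std::is_same_v<std::tuple_element_t<0, const Pair2>, const int>);
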
diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp
index d6b1fafb..38da7ace 100644
--- a/include/cute/container/array_subbyte.hpp
+++ b/include/cute/container/array_subbyte.hpp
@@ -611,17 +611,6 @@ struct tuple_element>
using type = T;
};
-template <class T, size_t N>
-struct tuple_size<cute::array_subbyte<T,N> const>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, N>
-{};
-
-template <size_t I, class T, size_t N>
-struct tuple_element<I, cute::array_subbyte<T,N> const>
-{
- using type = T;
-};
-
} // end namespace CUTE_STL_NAMESPACE
#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
@@ -647,16 +636,5 @@ struct tuple_element>
using type = T;
};
-template <class T, size_t N>
-struct tuple_size<cute::array_subbyte<T,N> const>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, N>
-{};
-
-template <size_t I, class T, size_t N>
-struct tuple_element<I, cute::array_subbyte<T,N> const>
-{
- using type = T;
-};
-
} // end namespace std
#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
diff --git a/include/cute/container/packed_tuple.hpp b/include/cute/container/packed_tuple.hpp
deleted file mode 100644
index a7a1c3b2..00000000
--- a/include/cute/container/packed_tuple.hpp
+++ /dev/null
@@ -1,254 +0,0 @@
-/***************************************************************************************************
- * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- * SPDX-License-Identifier: BSD-3-Clause
- *
- * Redistribution and use in source and binary forms, with or without
- * modification, are permitted provided that the following conditions are met:
- *
- * 1. Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
- *
- * 2. Redistributions in binary form must reproduce the above copyright notice,
- * this list of conditions and the following disclaimer in the documentation
- * and/or other materials provided with the distribution.
- *
- * 3. Neither the name of the copyright holder nor the names of its
- * contributors may be used to endorse or promote products derived from
- * this software without specific prior written permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
-#pragma once
-
-#include
-#include
-#include
-#include
-
-namespace cute {
-
-namespace detail {
-
-// Empty Structure Optimization
-template <bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
-struct ESO;
-
-template <class First, class... Rest>
-static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
-template <class First, class... Rest>
-static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
-
-template <class... T>
-using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
-
-// Empty First and Empty Rest...
-template <class First, class... Rest>
-struct ESO<true, true, First, Rest...> {
- CUTE_HOST_DEVICE constexpr
- ESO() {}
-
- CUTE_HOST_DEVICE constexpr
- ESO(First const&, Rest const&...) {}
-};
-
-// NonEmpty First and Empty Rest...
-template <class First, class... Rest>
-struct ESO<false, true, First, Rest...> {
- CUTE_HOST_DEVICE constexpr
- ESO() : first_{} {}
-
- CUTE_HOST_DEVICE constexpr
- ESO(First const& first, Rest const&...) : first_{first} {}
-
- First first_;
-};
-
-// Empty First and NonEmpty Rest...
-template <class First, class... Rest>
-struct ESO<true, false, First, Rest...> {
- CUTE_HOST_DEVICE constexpr
- ESO() : rest_{} {}
-
- CUTE_HOST_DEVICE constexpr
- ESO(First const&, Rest const&... rest) : rest_{rest...} {}
-
- ESO_t<Rest...> rest_;
-};
-
-// NonEmpty T and NonEmpty Rest...
-template <class First, class... Rest>
-struct ESO<false, false, First, Rest...> {
- CUTE_HOST_DEVICE constexpr
- ESO() : first_{}, rest_{} {}
-
- CUTE_HOST_DEVICE constexpr
- ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
-
- First first_;
- ESO_t<Rest...> rest_;
-};
-
-// Get Nth value from ESO
-template <size_t N, bool F, bool R, class T, class... Rest>
-CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...> const& s) {
- if constexpr (N == 0) {
- if constexpr (F) { return T{}; }
- else { return static_cast<T const&>(s.first_); }
- } else {
- if constexpr (R) { return cute::tuple_element_t<N, packed_tuple<T, Rest...>>{}; }
- else { return getv<N-1>(s.rest_); }
- }
-}
-
-template <size_t N, bool F, bool R, class T, class... Rest>
-CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>& s) {
- if constexpr (N == 0) {
- if constexpr (F) { return T{}; }
- else { return static_cast<T&>(s.first_); }
- } else {
- if constexpr (R) { return cute::tuple_element_t<N, packed_tuple<T, Rest...>>{}; }
- else { return getv<N-1>(s.rest_); }
- }
-}
-
-template <size_t N, bool F, bool R, class T, class... Rest>
-CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO<F, R, T, Rest...>&& s) {
- if constexpr (N == 0) {
- if constexpr (F) { return T{}; }
- else { return static_cast<T&&>(s.first_); }
- } else {
- if constexpr (R) { return cute::tuple_element_t<N, packed_tuple<T, Rest...>>{}; }
- else { return getv<N-1>(static_cast<ESO_t<Rest...>&&>(s.rest_)); }
- }
-}
-
-// findt: Implementation detail of cute::find.
-// If X is the first template argument of the tuple, findt returns C<0>.
-
-template <class X, size_t I = 0, bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
-CUTE_HOST_DEVICE constexpr
-auto
-findt(ESO<IsFirstEmpty, IsRestEmpty, First, Rest...> const& t) noexcept
-{
- if constexpr (cute::is_same_v<X, First>) {
- return C<I>{};
- }
- else {
- static_assert(sizeof...(Rest) != 0,
- "The type does not appear in the argument list of the tuple.");
- if constexpr (IsRestEmpty) {
- // The rest is empty, so creating an instance of it is cheap.
- return cute::detail::findt<X, I + 1>(ESO_t<Rest...>{});
- }
- else {
- return cute::detail::findt<X, I + 1>(t.rest_);
- }
- }
-}
-
-} // end namespace detail
-
-// packed_tuple is a tuple type that is a standard-layout type
-// whenever all of its template arguments are standard layout types:
-// (cute::is_standard_layout_v<T> && ...) implies (cute::is_standard_layout_v<packed_tuple<T...>>)
-
-template <class... T>
-struct packed_tuple : detail::ESO_t<T...>
-{
- CUTE_HOST_DEVICE constexpr
- packed_tuple() {}
-
- CUTE_HOST_DEVICE constexpr
- packed_tuple(T const&... ts)
- : detail::ESO_t<T...>(ts...)
- {}
-};
-
-template <>
-struct packed_tuple<> {};
-
-template <size_t I, class... T>
-CUTE_HOST_DEVICE constexpr
-decltype(auto)
-get(packed_tuple<T...> const& t) {
- static_assert(I < sizeof...(T), "Index out of range");
- return detail::getv<I>(t);
-}
-
-template <size_t I, class... T>
-CUTE_HOST_DEVICE constexpr
-decltype(auto)
-get(packed_tuple<T...>& t) {
- static_assert(I < sizeof...(T), "Index out of range");
- return detail::getv<I>(t);
-}
-
-template <size_t I, class... T>
-CUTE_HOST_DEVICE constexpr
-decltype(auto)
-get(packed_tuple<T...>&& t) {
- static_assert(I < sizeof...(T), "Index out of range");
- return detail::getv<I>(static_cast<packed_tuple<T...>&&>(t));
-}
-
-template <class... T>
-CUTE_HOST_DEVICE constexpr
-packed_tuple<T...>
-make_packed_tuple(T const&... t)
-{
- return {t...};
-}
-
-// Returns the position of type X (as a static integer) in the tuple
-// type's argument list. X must be unique in the argument list.
-template <class X, class... T>
-CUTE_HOST_DEVICE constexpr
-auto
-find(packed_tuple<T...> const& t) noexcept
-{
- return detail::findt<X>(t);
-}
-
-} // end namespace cute
-
-namespace CUTE_STL_NAMESPACE
-{
-
-template <class... T>
-struct tuple_size<cute::packed_tuple<T...>>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
-{};
-
-template <size_t I, class... T>
-struct tuple_element<I, cute::packed_tuple<T...>>
- : CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
-{};
-
-} // end namespace CUTE_STL_NAMESPACE
-
-#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
-namespace std {
-
-template <class... T>
-struct tuple_size<cute::packed_tuple<T...>>
- : CUTE_STL_NAMESPACE::integral_constant<size_t, sizeof...(T)>
-{};
-
-template <size_t I, class... T>
-struct tuple_element<I, cute::packed_tuple<T...>>
- : CUTE_STL_NAMESPACE::tuple_element<I, CUTE_STL_NAMESPACE::tuple<T...>>
-{};
-
-} // end namespace std
-#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
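
packed_tuple.hpp is deleted rather than lost: as the next hunk shows, its empty-structure-optimization technique moves directly into cute/container/tuple.hpp, making cute::tuple unconditionally standard-layout and empty-aware instead of gating that behind CUTLASS_USE_PACKED_TUPLE. A sketch of the properties this buys, assuming the guarantees stated in the new tuple.hpp comments:

    #include <type_traits>
    #include <cute/container/tuple.hpp>
    #include <cute/numeric/integral_constant.hpp>

    // Empty elements (compile-time constants like cute::C<v>) are never stored:
    static_assert(std::is_empty_v<cute::tuple<cute::C<1>, cute::C<2>>>);
    static_assert(sizeof(cute::tuple<cute::C<1>, int>) == sizeof(int));
    // Standard layout is preserved, so such tuples are safe kernel parameters:
    static_assert(std::is_standard_layout_v<cute::tuple<cute::C<1>, int>>);
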
diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp
index dab8621e..e62cfe16 100644
--- a/include/cute/container/tuple.hpp
+++ b/include/cute/container/tuple.hpp
@@ -37,169 +37,183 @@
#include
#include
-#if defined(CUTLASS_USE_PACKED_TUPLE)
-# include <cute/container/packed_tuple.hpp>
-#endif
-
//#include // Advanced optimizations
-// cute::tuple is like std::tuple, with two differences.
+// cute::tuple is like std::tuple, with a few differences:
//
// 1. It works on both host and device.
// 2. Its template arguments must be semiregular types.
+// 3. It is always a standard-layout type if all of its template arguments are standard-layout types.
+// 4. It is always an empty type if all of its template arguments are empty types.
//
// Semiregular types are default constructible and copyable.
// They include "value types" like int or float,
// but do _not_ include references like int& or float&.
// (See std::tie for an example of a tuple of references.)
//
-// If the template arguments of cute::tuple are all empty types (in
-// the sense of std::is_empty_v), then the cute::tuple is also an
-// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined,
-// cute::tuple is always a standard-layout type if all of its template
-// arguments are standard-layout types.
-
-namespace cute
-{
-
-#if defined(CUTLASS_USE_PACKED_TUPLE)
-
-template <class... T>
-using tuple = packed_tuple<T...>;
-
-#else
-
-namespace detail
-{
-
-// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
+// Standard-layout types preserve ABI across host-device boundaries.
+// They are safe to use as device kernel parameters.
+//
+// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of
// the conversion SFINAE, special overloading, and avoiding cvref template types.
//
// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x.
-// EBO stands for "empty base optimization."
+namespace cute
+{
+
+namespace detail
+{
+
+// ESO stands for "empty structure optimization."
// We use this technique to ensure that cute::tuple
-// doesn't need to waste space storing any template arguments
-// of cute::tuple that have no data (like integral_constant).
-// Otherwise, cute::tuple would need to spend at least 1 byte
-// for each of its template arguments.
-//
-// This is one way in which cute::tuple differs from std::tuple.
+// doesn't waste space storing template arguments that have no data (like integral_constant).
// Empty types in the template argument list are not even constructed,
-// and do not have unique element addresses. In fact, they are not
-// even members of the tuple or stored in any way. Calling `get`
+// and do not have unique element addresses. Calling `get`
// constructs and returns an instance of an empty type on demand.
-//
-// EBO always "holds" a single value of type T.
-// N is like an array index that TupleBase uses
-// to access the desired tuple element.
-template <size_t N, class T, bool IsEmpty = cute::is_empty<T>::value>
-struct EBO;
-template <class X, size_t N, bool E>
-CUTE_HOST_DEVICE constexpr C<N> findt(EBO<N, X, E> const&)
-{ return {}; }
+template <bool IsFirstEmpty, bool IsRestEmpty, class First, class... Rest>
+struct ESO;
-// Specialization for types T that have no data;
-// the "static tuple leaf." Valid T here include
-// integral_constant<U, u>, Int<N>,
-// and any other semiregular type
-// for which std::is_empty_v<T> is true.
-template <size_t N, class T>
-struct EBO<N, T, true>
-{
+template <class First, class... Rest>
+static constexpr bool is_first_empty_v = cute::is_empty<First>::value;
+template <class First, class... Rest>
+static constexpr bool is_rest_empty_v = (cute::is_empty<Rest>::value && ...);
+
+template <class... T>
+using ESO_t = ESO<is_first_empty_v<T...>, is_rest_empty_v<T...>, T...>;
+
+// Empty First and Empty Rest...
+template <class First, class... Rest>
+struct ESO<true, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
- EBO() {}
+ ESO() {}
CUTE_HOST_DEVICE constexpr
- EBO(T const&) {}
+ ESO(First const&, Rest const&...) {}
};
-template <size_t N, class T>
-CUTE_HOST_DEVICE constexpr T getv(EBO<N, T, true> const&)
-{ return {}; }
-
-// This is a work around approach to solve a shared memory misalign issue (https://github.com/NVIDIA/cutlass/issues/1250).
-// Will remove this work around implementation once the corresponding fix in compiler is released.
-struct dummy_EBO_base {};
-
-// Specialization for types T that are not empty;
-// the "dynamic tuple leaf." Valid T here include int,
-// any other integral or floating-point type,
-// or any semiregular type for which std::is_empty_v is false.
-template <size_t N, class T>
-struct EBO<N, T, false> : private dummy_EBO_base
-{
+// NonEmpty First and Empty Rest...
+template <class First, class... Rest>
+struct ESO<false, true, First, Rest...> {
CUTE_HOST_DEVICE constexpr
- EBO() : t_{} {}
+ ESO() : first_{} {}
CUTE_HOST_DEVICE constexpr
- EBO(T const& t) : t_{t} {}
+ ESO(First const& first, Rest const&...) : first_{first} {}
- T t_;
+ First first_;
};
-template <size_t N, class T>
-CUTE_HOST_DEVICE constexpr T const& getv(EBO<N, T, false> const& x)
-{ return x.t_; }
-
-template <size_t N, class T>
-CUTE_HOST_DEVICE constexpr T& getv(EBO<N, T, false>& x)
-{ return x.t_; }
-
-template <size_t N, class T>
-CUTE_HOST_DEVICE constexpr T&& getv(EBO<N, T, false>&& x)
-{ return cute::move(x.t_); }
-
-template <class IdxSeq, class... T>
-struct TupleBase;
-
-// Base class of cute::tuple binds each element to an index
-// by inheriting from EBO<i, t> for each (i, t) in (I..., T...).
-// The storage (for nonempty t) lives in the base classes.
-template <size_t... I, class... T>
-struct TupleBase<index_sequence<I...>, T...>
- : EBO<I, T>...
-{
+// Empty First and NonEmpty Rest...
+template <class First, class... Rest>
+struct ESO<true, false, First, Rest...> {
CUTE_HOST_DEVICE constexpr
- TupleBase() {}
+ ESO() : rest_{} {}
CUTE_HOST_DEVICE constexpr
- TupleBase(T const&... t) : EBO<I, T>(t)... {}
+ ESO(First const&, Rest const&... rest) : rest_{rest...} {}
+
+ ESO_t<Rest...> rest_;
};
+// NonEmpty T and NonEmpty Rest...
+template <class First, class... Rest>
+struct ESO<false, false, First, Rest...> {
+ CUTE_HOST_DEVICE constexpr
+ ESO() : first_{}, rest_{} {}
+
+ CUTE_HOST_DEVICE constexpr
+ ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {}
+
+ First first_;
+ ESO_t<Rest...> rest_;
+};
+
+// Get Nth value from ESO
+template <size_t N, bool F, bool R, class... T>
+CUTE_HOST_DEVICE constexpr
+cute::enable_if_t<cute::is_empty<cute::tuple_element_t<N, cute::tuple<T...>>>::value,
+ cute::tuple_element_t<N, cute::tuple<T...>>>
+getv(ESO<F, R, T...> const&)
+{
+ return {};
+}
+
+template <size_t N, bool F, bool R, class First, class... Rest>
+CUTE_HOST_DEVICE constexpr
+cute::enable_if_t<!cute::is_empty<cute::tuple_element_t<N, cute::tuple<First, Rest...>>>::value,
+ cute::tuple_element_t<N, cute::tuple<First, Rest...>> const&>
+getv(ESO<F, R, First, Rest...> const& s)
+{
+ if constexpr (N == 0) {
+ return static_cast<First const&>(s.first_);
+ } else {
+ return getv<N-1>(s.rest_);
+ }
+}
+
+template <size_t N, bool F, bool R, class First, class... Rest>
+CUTE_HOST_DEVICE constexpr
+cute::enable_if_t<!cute::is_empty<cute::tuple_element_t<N, cute::tuple<First, Rest...>>>::value,
+ cute::tuple_element_t